get basic training working

This commit is contained in:
Mike J Innes 2016-09-29 20:50:43 +01:00
parent 9e35bcd4b7
commit a2aade718d
2 changed files with 29 additions and 8 deletions

View File

@ -1,6 +1,10 @@
module TF
using ..Flux, Flow, TensorFlow
import Juno: info
import Flux: accuracy
export tf
# Workaround for tensor display bug
using Juno
@ -12,10 +16,8 @@ cvalue(v::Vertex) = cvalue(value(v))
graph(x::Tensor) = x
matrixify(xs) = xs
matrixify(xs::Vector) = xs[:,1:1]
# TODO: detect variable reuse
graph{T<:AArray}(p::Flux.Param{T}) = Variable(matrixify(p.x))
graph{T<:AArray}(p::Flux.Param{T}) = Variable(p.x')
function graph(model::Model, args...)
g = Flux.graph(model)
@ -28,8 +30,9 @@ function graph(model::Model, args...)
end |> value
end
graph(::typeof(*), args...) = *(args...)
graph(::typeof(*), args...) = *(reverse(args)...)
graph(::typeof(+), args...) = +(args...)
graph(::typeof(softmax), x) = nn.softmax(x)
type Model
session::Session
@ -41,7 +44,7 @@ end
Media.render(::Juno.Clipboard, ::Model) = "Flux.TF.Model()"
function tf(model)
sess = Session()
sess = Session(Graph())
input = placeholder(Float64)
g = graph(model, input)
run(sess, initialize_all_variables())
@ -50,7 +53,7 @@ end
function (m::Model)(args...)
@assert length(args) == length(m.inputs)
run(m.session, m.graph, Dict(zip(m.inputs, args)))
run(m.session, m.graph, Dict(zip(m.inputs, map(transpose, args))))'
end
function Flux.back!(m::Model, Δ, args...)
@ -63,4 +66,22 @@ function Flux.update!(m::Model)
error("update! is not yet supported on TensorFlow models")
end
function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
loss = (y, y) -> reduce_sum((y - y).^2)/2,
opt = TensorFlow.train.GradientDescentOptimizer(η))
i = 0
Y = placeholder(Float64)
Loss = loss(m.graph, Y)
minimize_op = TensorFlow.train.minimize(opt, Loss)
run(m.session, initialize_all_variables())
for e in 1:epoch
info("Epoch $e\n")
@progress for (x, y) in train
y, cur_loss, _ = run(m.session, vcat(m.graph, Loss, minimize_op), Dict(m.inputs[1]=>x', Y=>y'))
i % 1000 == 0 && @show accuracy(m, test)
i += 1
end
end
end
end

View File

@ -5,7 +5,7 @@ const AArray = AbstractArray
onehot(label, labels) = [i == label for i in labels]
onecold(pred, labels = 1:length(pred)) = labels[findfirst(pred, maximum(pred))]
initn(dims...) = randn(dims...)/100
initn(dims...) = randn(dims...)/1000
function train!(m, train, test = []; epoch = 1, batch = 10, η = 0.1)
i = 0
@ -24,7 +24,7 @@ function train!(m, train, test = []; epoch = 1, batch = 10, η = 0.1)
return m
end
function accuracy(m::Model, data)
function accuracy(m, data)
correct = 0
for (x, y) in data
onecold(m(x)) == onecold(y) && (correct += 1)