Flux.jl/src/backend/tensorflow/model.jl
Mike J Innes 0ad569596b tf reorg
2016-10-26 14:25:10 +01:00

57 lines
1.4 KiB
Julia
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Handle bundling a TensorFlow session with the tensors this file runs in it.
# Instances are built by `tf(model)`.
type Model
    session::Session        # session every `run` call in this file goes through
    inputs::Vector{Tensor}  # tensors used as the keys of the feed dictionary
    graph::Tensor           # tensor evaluated when the model is called
    grad::Tensor            # tensor evaluated by `Flux.back!`
end
# Build a `Model` from `model`:
#   1. open a session on a new `Graph`,
#   2. create one `Float32` placeholder and pass it, with `model`, to `graph`,
#   3. run `initialize_all_variables()` in the session,
#   4. store the resulting tensor plus `gradients` of it w.r.t. the placeholder.
# The returned `Model` has exactly one input tensor.
function tf(model)
    session = Session(Graph())
    x = placeholder(Float32)
    output = graph(model, x)
    run(session, initialize_all_variables())
    grad = gradients(output, x)
    return Model(session, [x], output, grad)
end
# Wrap a single value `x` in a `Batch` built from the one-element tuple `(x,)`.
function batch(x)
    return Batch((x,))
end
# Evaluate `m.graph` in the model's session, feeding each input tensor the
# positionally matching `Batch` argument. Exactly one argument per entry of
# `m.inputs` is required (checked with `@assert`).
function (m::Model)(args::Batch...)
    @assert length(args) == length(m.inputs)
    feed = Dict(zip(m.inputs, args))
    return run(m.session, m.graph, feed)
end
# Fallback call for arguments that are not all `Batch`es: wrap every argument
# with `batch` and call the model again with the wrapped values.
function (m::Model)(args...)
    batched = map(batch, args)
    return m(batched...)
end
# Evaluate the stored gradient tensor `m.grad` in the model's session, feeding
# each input tensor the positionally matching element of `args`. Exactly one
# argument per entry of `m.inputs` is required (checked with `@assert`).
# NOTE(review): `Δ` is accepted but never used in this body — presumably what
# the TODO below is about; confirm before relying on the result.
function Flux.back!(m::Model, Δ, args...)
    @assert length(args) == length(m.inputs)
    feed = Dict(zip(m.inputs, args))
    # TODO: keyword arguments to `gradients`
    return run(m.session, m.grad, feed)
end
# Placeholder method: always raises an error, since `update!` has no
# implementation for TensorFlow-backed models.
Flux.update!(m::Model) =
    error("update! is not yet supported on TensorFlow models")
import Juno: info
# Train the TensorFlow-backed model `m` by repeatedly running an optimizer's
# minimize op in the model's session.
#
# Arguments
#   m      – the `Model` to train; only its first input tensor (`m.inputs[1]`)
#            is fed.
#   train  – iterable yielding `(x, target)` pairs; each element is wrapped
#            with `batch` before being fed.
#   test   – passed unchanged to `accuracy(m, test)` for progress reports.
#
# Keywords
#   epoch  – number of passes over `train` (loop runs `1:epoch`).
#   η      – value handed to the default `GradientDescentOptimizer`.
#   loss   – function `(prediction, target) -> tensor` applied once to
#            `m.graph` and a `Float32` target placeholder; the default is half
#            the summed squared difference.
#   opt    – optimizer passed to `TensorFlow.train.minimize`.
#
# The two parameters of the default `loss` must have distinct names: a lambda
# with duplicate argument names is rejected by Julia, and the difference would
# otherwise be taken between a value and itself.
function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
                     loss = (y, target) -> reduce_sum((y - target).^2)/2,
                     opt = TensorFlow.train.GradientDescentOptimizer(η))
    i = 0  # iterations performed so far, counted across epochs
    Y = placeholder(Float32)  # fed with the target of each training pair
    Loss = loss(m.graph, Y)
    minimize_op = TensorFlow.train.minimize(opt, Loss)
    for e in 1:epoch
        info("Epoch $e\n")
        @progress for (x, target) in train
            # One run evaluates the model output, the loss, and the minimize op.
            y, cur_loss, _ = run(m.session, vcat(m.graph, Loss, minimize_op),
                                 Dict(m.inputs[1]=>batch(x), Y=>batch(target)))
            # Report on the first iteration and every 5000th one after it.
            if i % 5000 == 0
                @show y
                @show accuracy(m, test)
            end
            i += 1
        end
    end
end