tf.model refactor

Mike J Innes 2016-10-28 15:13:58 +01:00
parent 27aa2bf8d4
commit 740d868ef9
2 changed files with 22 additions and 14 deletions

@@ -42,15 +42,20 @@ Flux.shape(op::Op, d...) = op.shape(d...)
 # TODO: detect variable reuse
 graph{T<:AArray}(p::Flux.Param{T}) = Variable(p.x)
-function graph(model::Model, args...)
-  g = Flux.graph(model)
-  g ≠ nothing || error("No graph for $model")
-  g = spliceinputs(g, map(constant, args)...) |> detuple
-  postwalk(g) do v
+function graph(v::IVertex, args...)
+  # TODO: check number of arguments
+  v = spliceinputs(v, map(constant, args)...) |> detuple
+  postwalk(v) do v
     vertex(graph(cvalue(v), cvalue.(inputs(v))...))
   end |> value
 end
+
+function graph(model::Flux.Model, args...)
+  g = Flux.graph(model)
+  g ≠ nothing || error("No graph for $model")
+  graph(g, args...)
+end
 
 TensorFlow.Tensor(m::Flux.Model, args...) = graph(m, args...)
 RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
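
The refactor splits `graph` into two methods: `graph(v::IVertex, args...)` splices concrete inputs into an already-extracted vertex, while `graph(model::Flux.Model, args...)` pulls the graph out of a Flux model and delegates to the first. A rough usage sketch, not part of this commit (the `Affine` layer and shapes are assumptions for illustration only):

    m = Flux.Affine(10, 5)            # any Flux model that defines Flux.graph (Affine assumed here)
    x = TensorFlow.placeholder(Float32)
    t = TensorFlow.Tensor(m, x)       # graph(::Flux.Model, x) -> graph(::IVertex, x) -> a TF tensor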

@@ -1,23 +1,26 @@
 type Model
+  model
   session::Session
+  vars::Dict{Flux.Param,Tensor}
   inputs::Vector{Tensor}
-  graph::Tensor
-  grad::Tensor
+  outputs::Vector{Tensor}
+  gradients::Vector{Tensor}
 end
 
 function tf(model)
-  sess = Session(Graph())
+  sess = Session()
+  vars = Dict{Flux.Param,Tensor}()
   input = placeholder(Float32)
-  g = graph(model, input)
+  output = graph(model, input)
   run(sess, initialize_all_variables())
-  Model(sess, [input], g, gradients(g, input))
+  Model(model, sess, vars, [input], [output], [gradients(output, input)])
 end
 
 batch(x) = Batch((x,))
 
 function (m::Model)(args::Batch...)
   @assert length(args) == length(m.inputs)
-  run(m.session, m.graph, Dict(zip(m.inputs, args)))
+  run(m.session, m.outputs[1], Dict(zip(m.inputs, args)))
 end
 
 (m::Model)(args...) = m(map(batch, args)...)
@@ -25,7 +28,7 @@ end
 function Flux.back!(m::Model, Δ, args...)
   @assert length(args) == length(m.inputs)
   # TODO: keyword arguments to `gradients`
-  run(m.session, m.grad, Dict(zip(m.inputs, args)))
+  run(m.session, m.gradients[1], Dict(zip(m.inputs, args)))
 end
 
 function Flux.update!(m::Model)
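
Since the wrapper now stores vectors of outputs and gradients (plus the underlying model and a parameter-to-tensor map), the forward and backward passes simply fetch the first element of each. A hedged sketch of driving the wrapper, using only calls shown in this file (`model`, `x`, and `Δ` are placeholders, not part of the commit):

    m = tf(model)          # builds the session, input placeholder and output tensor
    y = m(x)               # forward: feeds x to m.inputs[1], fetches m.outputs[1]
    Flux.back!(m, Δ, x)    # backward: fetches m.gradients[1] for the same feed
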
@@ -39,12 +42,12 @@ function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
                      opt = TensorFlow.train.GradientDescentOptimizer(η))
   i = 0
   Y = placeholder(Float32)
-  Loss = loss(m.graph, Y)
+  Loss = loss(m.outputs[1], Y)
   minimize_op = TensorFlow.train.minimize(opt, Loss)
   for e in 1:epoch
     info("Epoch $e\n")
     @progress for (x, y) in train
-      y, cur_loss, _ = run(m.session, vcat(m.graph, Loss, minimize_op),
+      y, cur_loss, _ = run(m.session, vcat(m.outputs[1], Loss, minimize_op),
                            Dict(m.inputs[1]=>batch(x), Y=>batch(y)))
       if i % 5000 == 0
         @show y