tf.model refactor
parent 27aa2bf8d4
commit 740d868ef9
@@ -42,15 +42,20 @@ Flux.shape(op::Op, d...) = op.shape(d...)
 # TODO: detect variable reuse
 graph{T<:AArray}(p::Flux.Param{T}) = Variable(p.x)
 
-function graph(model::Model, args...)
-  g = Flux.graph(model)
-  g ≠ nothing || error("No graph for $model")
-  g = spliceinputs(g, map(constant, args)...) |> detuple
-  postwalk(g) do v
+function graph(v::IVertex, args...)
+  # TODO: check number of arguments
+  v = spliceinputs(v, map(constant, args)...) |> detuple
+  postwalk(v) do v
     vertex(graph(cvalue(v), cvalue.(inputs(v))...))
   end |> value
 end
 
+function graph(model::Flux.Model, args...)
+  g = Flux.graph(model)
+  g ≠ nothing || error("No graph for $model")
+  graph(g, args...)
+end
+
 TensorFlow.Tensor(m::Flux.Model, args...) = graph(m, args...)
 
 RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
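The refactor splits graph lowering in two: graph(::Flux.Model, args...) fetches the model's stored graph, and graph(::IVertex, args...) splices the argument tensors into it. A minimal usage sketch (the Affine layer and its sizes are assumptions, not part of this diff):

x = placeholder(Float32)       # TensorFlow.jl input, as used in `tf` below
m = Flux.Affine(10, 5)         # hypothetical model; anything that defines Flux.graph works
t = TensorFlow.Tensor(m, x)    # goes through graph(::Flux.Model, x), then graph(::IVertex, x)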
@@ -1,23 +1,26 @@
 type Model
+  model
   session::Session
+  vars::Dict{Flux.Param,Tensor}
   inputs::Vector{Tensor}
-  graph::Tensor
-  grad::Tensor
+  outputs::Vector{Tensor}
+  gradients::Vector{Tensor}
 end
 
 function tf(model)
-  sess = Session(Graph())
+  sess = Session()
+  vars = Dict{Flux.Param,Tensor}()
   input = placeholder(Float32)
-  g = graph(model, input)
+  output = graph(model, input)
   run(sess, initialize_all_variables())
-  Model(sess, [input], g, gradients(g, input))
+  Model(model, sess, vars, [input], [output], [gradients(output, input)])
 end
 
 batch(x) = Batch((x,))
 
 function (m::Model)(args::Batch...)
   @assert length(args) == length(m.inputs)
-  run(m.session, m.graph, Dict(zip(m.inputs, args)))
+  run(m.session, m.outputs[1], Dict(zip(m.inputs, args)))
 end
 
 (m::Model)(args...) = m(map(batch, args)...)
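Model now carries the original model, a parameter-to-tensor map and explicit inputs/outputs/gradients vectors in place of the single graph/grad tensors. A hedged call sketch, reusing the hypothetical Affine model from above:

m = tf(Flux.Affine(10, 5))     # opens a Session with one placeholder input and one output tensor
y = m(rand(10))                # wraps the sample in a Batch and runs m.outputs[1]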
@@ -25,7 +28,7 @@ end
 function Flux.back!(m::Model, Δ, args...)
   @assert length(args) == length(m.inputs)
   # TODO: keyword arguments to `gradients`
-  run(m.session, m.grad, Dict(zip(m.inputs, args)))
+  run(m.session, m.gradients[1], Dict(zip(m.inputs, args)))
 end
 
 function Flux.update!(m::Model)
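back! correspondingly runs the first entry of m.gradients; note that Δ is not yet fed into the graph (hence the TODO above). Continuing the sketch:

Flux.back!(m, ones(5), rand(10))   # evaluates m.gradients[1] with rand(10) bound to m.inputs[1]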
@@ -39,12 +42,12 @@ function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
                      opt = TensorFlow.train.GradientDescentOptimizer(η))
   i = 0
   Y = placeholder(Float32)
-  Loss = loss(m.graph, Y)
+  Loss = loss(m.outputs[1], Y)
   minimize_op = TensorFlow.train.minimize(opt, Loss)
   for e in 1:epoch
     info("Epoch $e\n")
     @progress for (x, y) in train
-      y, cur_loss, _ = run(m.session, vcat(m.graph, Loss, minimize_op),
+      y, cur_loss, _ = run(m.session, vcat(m.outputs[1], Loss, minimize_op),
                            Dict(m.inputs[1]=>batch(x), Y=>batch(y)))
       if i % 5000 == 0
         @show y
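train! likewise reads the forward output from m.outputs[1]. A hedged usage sketch with made-up data:

data = [(rand(10), rand(5)) for _ = 1:100]   # hypothetical (x, y) pairs
Flux.train!(m, data; epoch = 2, η = 0.01)    # feeds each pair through the loss and minimize ops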