diff --git a/src/backend/tensorflow/model.jl b/src/backend/tensorflow/model.jl
index c8f61850..85459076 100644
--- a/src/backend/tensorflow/model.jl
+++ b/src/backend/tensorflow/model.jl
@@ -1,18 +1,21 @@
 using Flux: mapt, collectt, shapecheckt
 
 struct Exec
-  session::Session
-  input::Any
-  output::Any
-  params::Dict{Flux.Param,Tensor}
-  stacks::Dict{Any,Any}
+  session ::Session
+  input   ::Any
+  output  ::Any
+  grads   ::Any
+  params  ::Dict{Flux.Param,Tensor}
+  stacks  ::Dict{Any,Any}
 end
 
 function makesession(model, inputs; session = Session(Graph()))
   inputs = mapt(_ -> placeholder(Float32), inputs)
   params, stacks, output = tograph(model, inputs...)
+  # grads = gradients(output, [collectt(inputs)..., values(params)...])
+  grads = placeholder(Float32)
   run(session, global_variables_initializer())
-  Exec(session, inputs, output, params, stacks)
+  Exec(session, inputs, output, grads, params, stacks)
 end
 
 retuple(xs) = xs
@@ -20,11 +23,37 @@ retuple(xs::AbstractArray{<:AbstractArray}) = (retuple.(xs)...,)
 
 dictt(xs, ys) = Dict(zip(collectt(xs), collectt(ys)))
 
-function (m::Exec)(args...)
+function params(m::Exec, args...)
   shapecheckt(m.input, args)
   idict = dictt(m.input, args)
   pdict = Dict(t => p.x for (p, t) in m.params)
-  retuple(run(m.session, m.output, merge(idict, pdict)))
+  merge(idict, pdict)
+end
+
+function (m::Exec)(args...)
+  retuple(run(m.session, m.output, params(m, args...)))
+end
+
+pullt!(_, xs) = shift!(xs)
+pullt!(x::Tuple, xs) = map(x -> pullt!(x, xs), x)
+
+# TODO: gradients don't work yet
+# `gradients` lacks support for `grad_y`s and multiple `y`s
+
+function Flux.back!(m::Exec, Δ, args...)
+  Δps = run(m.session, m.grads, params(m, args...))
+  Δin = pullt!(m.input, Δps)
+  for (p, Δ) in zip(keys(m.params), Δps)
+    p.Δx .+= Δ
+  end
+  Δin
+end
+
+function Flux.update!(m::Exec, η)
+  for p in keys(m.params)
+    update!(p, η)
+  end
+  return m
 end
 
 mutable struct Model
@@ -41,11 +70,8 @@ function (m::Model)(args...)
   @tferr m.exec.stacks m.exec(args...)
 end
 
-for f in :[back!, update!].args
-  @eval function Flux.$f(m::Model, args...)
-    error($(string(f)) * " is not yet supported on TensorFlow models")
-  end
-end
+Flux.back!(m::Model, Δ, args...) = back!(m.exec, Δ, args...)
+Flux.update!(m::Model, η) = (update!(m.exec, η); m)
 
 # Recurrent Models
 