diff --git a/src/backend/tensorflow/model.jl b/src/backend/tensorflow/model.jl index f1b1c6cc..a0a3cebd 100644 --- a/src/backend/tensorflow/model.jl +++ b/src/backend/tensorflow/model.jl @@ -4,7 +4,6 @@ type Model params::Dict{Flux.Param,Tensor} inputs::Vector{Tensor} output::Any - gradients::Vector{Tensor} end ismultioutput(m::Model) = !isa(m.output, Tensor) @@ -12,8 +11,7 @@ ismultioutput(m::Model) = !isa(m.output, Tensor) function tf(model) sess, params, input, output = makesession(model, 1) Model(model, sess, params, - input, output, - gradients(output, input)) + input, output) end function batch(xs) @@ -38,14 +36,10 @@ function (m::Model)(args...) ismultioutput(m) ? map(first, output) : first(output) end -function Flux.back!(m::Model, Δ, args...) - @assert length(args) == length(m.inputs) - # TODO: keyword arguments to `gradients` - run(m.session, m.gradients[1], Dict(zip(m.inputs, args))) -end - -function Flux.update!(m::Model) - error("update! is not yet supported on TensorFlow models") +for f in :[back!, update!].args + @eval function Flux.$f(m::Model, args...) + error($(string(f)) * " is not yet supported on TensorFlow models") + end end import Juno: info diff --git a/src/backend/tensorflow/recurrent.jl b/src/backend/tensorflow/recurrent.jl index 5e53d9a8..eebb3d99 100644 --- a/src/backend/tensorflow/recurrent.jl +++ b/src/backend/tensorflow/recurrent.jl @@ -27,8 +27,7 @@ function tf(model::Flux.Unrolled) sess, params, (instates, input), (outstates, output) = makesession(model) SeqModel( Model(model, sess, params, - [instates..., input], [outstates..., output], - [placeholder(Float32)]), + [instates..., input], [outstates..., output]), model.state) end