remove gradient support for now

This commit is contained in:
Mike J Innes 2016-12-20 16:37:43 +00:00
parent f74ca7f7cf
commit 22568452f1
2 changed files with 6 additions and 13 deletions

View File

@ -4,7 +4,6 @@ type Model
params::Dict{Flux.Param,Tensor}
inputs::Vector{Tensor}
output::Any
gradients::Vector{Tensor}
end
ismultioutput(m::Model) = !isa(m.output, Tensor)
@ -12,8 +11,7 @@ ismultioutput(m::Model) = !isa(m.output, Tensor)
function tf(model)
sess, params, input, output = makesession(model, 1)
Model(model, sess, params,
input, output,
gradients(output, input))
input, output)
end
function batch(xs)
@ -38,14 +36,10 @@ function (m::Model)(args...)
ismultioutput(m) ? map(first, output) : first(output)
end
function Flux.back!(m::Model, Δ, args...)
@assert length(args) == length(m.inputs)
# TODO: keyword arguments to `gradients`
run(m.session, m.gradients[1], Dict(zip(m.inputs, args)))
end
function Flux.update!(m::Model)
error("update! is not yet supported on TensorFlow models")
for f in :[back!, update!].args
@eval function Flux.$f(m::Model, args...)
error($(string(f)) * " is not yet supported on TensorFlow models")
end
end
import Juno: info

View File

@ -27,8 +27,7 @@ function tf(model::Flux.Unrolled)
sess, params, (instates, input), (outstates, output) = makesession(model)
SeqModel(
Model(model, sess, params,
[instates..., input], [outstates..., output],
[placeholder(Float32)]),
[instates..., input], [outstates..., output]),
model.state)
end