remove gradient support for now
This commit is contained in:
parent
f74ca7f7cf
commit
22568452f1
@ -4,7 +4,6 @@ type Model
|
||||
params::Dict{Flux.Param,Tensor}
|
||||
inputs::Vector{Tensor}
|
||||
output::Any
|
||||
gradients::Vector{Tensor}
|
||||
end
|
||||
|
||||
ismultioutput(m::Model) = !isa(m.output, Tensor)
|
||||
@ -12,8 +11,7 @@ ismultioutput(m::Model) = !isa(m.output, Tensor)
|
||||
function tf(model)
|
||||
sess, params, input, output = makesession(model, 1)
|
||||
Model(model, sess, params,
|
||||
input, output,
|
||||
gradients(output, input))
|
||||
input, output)
|
||||
end
|
||||
|
||||
function batch(xs)
|
||||
@ -38,14 +36,10 @@ function (m::Model)(args...)
|
||||
ismultioutput(m) ? map(first, output) : first(output)
|
||||
end
|
||||
|
||||
function Flux.back!(m::Model, Δ, args...)
|
||||
@assert length(args) == length(m.inputs)
|
||||
# TODO: keyword arguments to `gradients`
|
||||
run(m.session, m.gradients[1], Dict(zip(m.inputs, args)))
|
||||
end
|
||||
|
||||
function Flux.update!(m::Model)
|
||||
error("update! is not yet supported on TensorFlow models")
|
||||
for f in :[back!, update!].args
|
||||
@eval function Flux.$f(m::Model, args...)
|
||||
error($(string(f)) * " is not yet supported on TensorFlow models")
|
||||
end
|
||||
end
|
||||
|
||||
import Juno: info
|
||||
|
@ -27,8 +27,7 @@ function tf(model::Flux.Unrolled)
|
||||
sess, params, (instates, input), (outstates, output) = makesession(model)
|
||||
SeqModel(
|
||||
Model(model, sess, params,
|
||||
[instates..., input], [outstates..., output],
|
||||
[placeholder(Float32)]),
|
||||
[instates..., input], [outstates..., output]),
|
||||
model.state)
|
||||
end
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user