remove gradient support for now
This commit is contained in:
parent
f74ca7f7cf
commit
22568452f1
@ -4,7 +4,6 @@ type Model
|
|||||||
params::Dict{Flux.Param,Tensor}
|
params::Dict{Flux.Param,Tensor}
|
||||||
inputs::Vector{Tensor}
|
inputs::Vector{Tensor}
|
||||||
output::Any
|
output::Any
|
||||||
gradients::Vector{Tensor}
|
|
||||||
end
|
end
|
||||||
|
|
||||||
ismultioutput(m::Model) = !isa(m.output, Tensor)
|
ismultioutput(m::Model) = !isa(m.output, Tensor)
|
||||||
@ -12,8 +11,7 @@ ismultioutput(m::Model) = !isa(m.output, Tensor)
|
|||||||
function tf(model)
|
function tf(model)
|
||||||
sess, params, input, output = makesession(model, 1)
|
sess, params, input, output = makesession(model, 1)
|
||||||
Model(model, sess, params,
|
Model(model, sess, params,
|
||||||
input, output,
|
input, output)
|
||||||
gradients(output, input))
|
|
||||||
end
|
end
|
||||||
|
|
||||||
function batch(xs)
|
function batch(xs)
|
||||||
@ -38,14 +36,10 @@ function (m::Model)(args...)
|
|||||||
ismultioutput(m) ? map(first, output) : first(output)
|
ismultioutput(m) ? map(first, output) : first(output)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Flux.back!(m::Model, Δ, args...)
|
for f in :[back!, update!].args
|
||||||
@assert length(args) == length(m.inputs)
|
@eval function Flux.$f(m::Model, args...)
|
||||||
# TODO: keyword arguments to `gradients`
|
error($(string(f)) * " is not yet supported on TensorFlow models")
|
||||||
run(m.session, m.gradients[1], Dict(zip(m.inputs, args)))
|
end
|
||||||
end
|
|
||||||
|
|
||||||
function Flux.update!(m::Model)
|
|
||||||
error("update! is not yet supported on TensorFlow models")
|
|
||||||
end
|
end
|
||||||
|
|
||||||
import Juno: info
|
import Juno: info
|
||||||
|
@ -27,8 +27,7 @@ function tf(model::Flux.Unrolled)
|
|||||||
sess, params, (instates, input), (outstates, output) = makesession(model)
|
sess, params, (instates, input), (outstates, output) = makesession(model)
|
||||||
SeqModel(
|
SeqModel(
|
||||||
Model(model, sess, params,
|
Model(model, sess, params,
|
||||||
[instates..., input], [outstates..., output],
|
[instates..., input], [outstates..., output]),
|
||||||
[placeholder(Float32)]),
|
|
||||||
model.state)
|
model.state)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user