handling of multiple outputs
This commit is contained in:
parent
1c6eaece5d
commit
e450a585b7
|
@ -3,17 +3,19 @@ type Model
|
|||
session::Session
|
||||
params::Dict{Flux.Param,Tensor}
|
||||
inputs::Vector{Tensor}
|
||||
outputs::Vector{Tensor}
|
||||
output
|
||||
gradients::Vector{Tensor}
|
||||
end
|
||||
|
||||
ismultioutput(m::Model) = !isa(m.output, Tensor)
|
||||
|
||||
function tf(model)
|
||||
sess = Session(Graph())
|
||||
input = placeholder(Float32)
|
||||
params, output = tograph(model, input)
|
||||
run(sess, initialize_all_variables())
|
||||
Model(model, sess, params,
|
||||
[input], [output],
|
||||
[input], output,
|
||||
[gradients(output, input)])
|
||||
end
|
||||
|
||||
|
@ -28,10 +30,14 @@ end
|
|||
|
||||
function (m::Model)(args::Batch...)
|
||||
@assert length(args) == length(m.inputs)
|
||||
batch(run(m.session, m.outputs[1], Dict(zip(m.inputs, args))))
|
||||
output = run(m.session, m.output, Dict(zip(m.inputs, args)))
|
||||
ismultioutput(m) ? (batch.(output)...,) : batch(output)
|
||||
end
|
||||
|
||||
(m::Model)(args...) = first(m(map(batchone, args)...))
|
||||
function (m::Model)(args...)
|
||||
output = m(map(batchone, args)...)
|
||||
ismultioutput(m) ? map(first, output) : first(output)
|
||||
end
|
||||
|
||||
function Flux.back!(m::Model, Δ, args...)
|
||||
@assert length(args) == length(m.inputs)
|
||||
|
|
Loading…
Reference in New Issue