handling of multiple outputs

This commit is contained in:
Mike J Innes 2016-10-28 20:50:27 +01:00
parent 1c6eaece5d
commit e450a585b7
1 changed files with 10 additions and 4 deletions

View File

@ -3,17 +3,19 @@ type Model
session::Session
params::Dict{Flux.Param,Tensor}
inputs::Vector{Tensor}
outputs::Vector{Tensor}
output
gradients::Vector{Tensor}
end
ismultioutput(m::Model) = !isa(m.output, Tensor)
function tf(model)
sess = Session(Graph())
input = placeholder(Float32)
params, output = tograph(model, input)
run(sess, initialize_all_variables())
Model(model, sess, params,
[input], [output],
[input], output,
[gradients(output, input)])
end
@ -28,10 +30,14 @@ end
function (m::Model)(args::Batch...)
@assert length(args) == length(m.inputs)
batch(run(m.session, m.outputs[1], Dict(zip(m.inputs, args))))
output = run(m.session, m.output, Dict(zip(m.inputs, args)))
ismultioutput(m) ? (batch.(output)...,) : batch(output)
end
(m::Model)(args...) = first(m(map(batchone, args)...))
function (m::Model)(args...)
output = m(map(batchone, args)...)
ismultioutput(m) ? map(first, output) : first(output)
end
function Flux.back!(m::Model, Δ, args...)
@assert length(args) == length(m.inputs)