handling of multiple outputs

This commit is contained in:
Mike J Innes 2016-10-28 20:50:27 +01:00
parent 1c6eaece5d
commit e450a585b7

View File

@ -3,17 +3,19 @@ type Model
session::Session session::Session
params::Dict{Flux.Param,Tensor} params::Dict{Flux.Param,Tensor}
inputs::Vector{Tensor} inputs::Vector{Tensor}
outputs::Vector{Tensor} output
gradients::Vector{Tensor} gradients::Vector{Tensor}
end end
ismultioutput(m::Model) = !isa(m.output, Tensor)
function tf(model) function tf(model)
sess = Session(Graph()) sess = Session(Graph())
input = placeholder(Float32) input = placeholder(Float32)
params, output = tograph(model, input) params, output = tograph(model, input)
run(sess, initialize_all_variables()) run(sess, initialize_all_variables())
Model(model, sess, params, Model(model, sess, params,
[input], [output], [input], output,
[gradients(output, input)]) [gradients(output, input)])
end end
@ -28,10 +30,14 @@ end
function (m::Model)(args::Batch...) function (m::Model)(args::Batch...)
@assert length(args) == length(m.inputs) @assert length(args) == length(m.inputs)
batch(run(m.session, m.outputs[1], Dict(zip(m.inputs, args)))) output = run(m.session, m.output, Dict(zip(m.inputs, args)))
ismultioutput(m) ? (batch.(output)...,) : batch(output)
end end
(m::Model)(args...) = first(m(map(batchone, args)...)) function (m::Model)(args...)
output = m(map(batchone, args)...)
ismultioutput(m) ? map(first, output) : first(output)
end
function Flux.back!(m::Model, Δ, args...) function Flux.back!(m::Model, Δ, args...)
@assert length(args) == length(m.inputs) @assert length(args) == length(m.inputs)