handling of multiple outputs
This commit is contained in:
parent
1c6eaece5d
commit
e450a585b7
@ -3,17 +3,19 @@ type Model
|
|||||||
session::Session
|
session::Session
|
||||||
params::Dict{Flux.Param,Tensor}
|
params::Dict{Flux.Param,Tensor}
|
||||||
inputs::Vector{Tensor}
|
inputs::Vector{Tensor}
|
||||||
outputs::Vector{Tensor}
|
output
|
||||||
gradients::Vector{Tensor}
|
gradients::Vector{Tensor}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
ismultioutput(m::Model) = !isa(m.output, Tensor)
|
||||||
|
|
||||||
function tf(model)
|
function tf(model)
|
||||||
sess = Session(Graph())
|
sess = Session(Graph())
|
||||||
input = placeholder(Float32)
|
input = placeholder(Float32)
|
||||||
params, output = tograph(model, input)
|
params, output = tograph(model, input)
|
||||||
run(sess, initialize_all_variables())
|
run(sess, initialize_all_variables())
|
||||||
Model(model, sess, params,
|
Model(model, sess, params,
|
||||||
[input], [output],
|
[input], output,
|
||||||
[gradients(output, input)])
|
[gradients(output, input)])
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -28,10 +30,14 @@ end
|
|||||||
|
|
||||||
function (m::Model)(args::Batch...)
|
function (m::Model)(args::Batch...)
|
||||||
@assert length(args) == length(m.inputs)
|
@assert length(args) == length(m.inputs)
|
||||||
batch(run(m.session, m.outputs[1], Dict(zip(m.inputs, args))))
|
output = run(m.session, m.output, Dict(zip(m.inputs, args)))
|
||||||
|
ismultioutput(m) ? (batch.(output)...,) : batch(output)
|
||||||
end
|
end
|
||||||
|
|
||||||
(m::Model)(args...) = first(m(map(batchone, args)...))
|
function (m::Model)(args...)
|
||||||
|
output = m(map(batchone, args)...)
|
||||||
|
ismultioutput(m) ? map(first, output) : first(output)
|
||||||
|
end
|
||||||
|
|
||||||
function Flux.back!(m::Model, Δ, args...)
|
function Flux.back!(m::Model, Δ, args...)
|
||||||
@assert length(args) == length(m.inputs)
|
@assert length(args) == length(m.inputs)
|
||||||
|
Loading…
Reference in New Issue
Block a user