diff --git a/src/backend/tensorflow/model.jl b/src/backend/tensorflow/model.jl index ca52bf15..37d92024 100644 --- a/src/backend/tensorflow/model.jl +++ b/src/backend/tensorflow/model.jl @@ -26,12 +26,16 @@ function batch(xs) Batch{T,B}(xs) end -function (m::Model)(args::Batch...) +function runmodel(m, args...) @assert length(args) == length(m.inputs) output = run(m.session, m.output, Dict(zip(m.inputs, args))) ismultioutput(m) ? (batch.(output)...,) : batch(output) end +function (m::Model)(args::Batch...) + runmodel(m, args...) +end + function (m::Model)(args...) output = m(map(batchone, args)...) ismultioutput(m) ? map(first, output) : first(output) diff --git a/src/backend/tensorflow/recurrent.jl b/src/backend/tensorflow/recurrent.jl index 2660196c..ac300f75 100644 --- a/src/backend/tensorflow/recurrent.jl +++ b/src/backend/tensorflow/recurrent.jl @@ -24,7 +24,7 @@ function tf(model::Flux.Unrolled) Model(model, sess, params, [instates..., input], [outstates..., output], [placeholder(Float32)]), - batchone.(model.state)) + model.state) end function batchseq(xs) @@ -35,11 +35,13 @@ function batchseq(xs) Batch{Seq{T,S},B}(xs) end +TensorFlow.get_tensors(x::Tuple) = TensorFlow.get_tensors(collect(x)) + function (m::SeqModel)(x::BatchSeq) if isempty(m.state) || length(first(m.state)) ≠ length(x) - m.state = map(batchone, m.m.model.states) + m.state = m.m.model.state end - output = m.m(m.state..., x) + output = runmodel(m.m, m.state..., x) m.state, output = output[1:end-1], output[end] return batchseq(rawbatch(output)) end