explicit hidden state batching
This commit is contained in:
parent
3c068744d2
commit
2d90d04789
@ -26,12 +26,16 @@ function batch(xs)
|
|||||||
Batch{T,B}(xs)
|
Batch{T,B}(xs)
|
||||||
end
|
end
|
||||||
|
|
||||||
function (m::Model)(args::Batch...)
|
function runmodel(m, args...)
|
||||||
@assert length(args) == length(m.inputs)
|
@assert length(args) == length(m.inputs)
|
||||||
output = run(m.session, m.output, Dict(zip(m.inputs, args)))
|
output = run(m.session, m.output, Dict(zip(m.inputs, args)))
|
||||||
ismultioutput(m) ? (batch.(output)...,) : batch(output)
|
ismultioutput(m) ? (batch.(output)...,) : batch(output)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function (m::Model)(args::Batch...)
|
||||||
|
runmodel(m, args...)
|
||||||
|
end
|
||||||
|
|
||||||
function (m::Model)(args...)
|
function (m::Model)(args...)
|
||||||
output = m(map(batchone, args)...)
|
output = m(map(batchone, args)...)
|
||||||
ismultioutput(m) ? map(first, output) : first(output)
|
ismultioutput(m) ? map(first, output) : first(output)
|
||||||
|
@ -24,7 +24,7 @@ function tf(model::Flux.Unrolled)
|
|||||||
Model(model, sess, params,
|
Model(model, sess, params,
|
||||||
[instates..., input], [outstates..., output],
|
[instates..., input], [outstates..., output],
|
||||||
[placeholder(Float32)]),
|
[placeholder(Float32)]),
|
||||||
batchone.(model.state))
|
model.state)
|
||||||
end
|
end
|
||||||
|
|
||||||
function batchseq(xs)
|
function batchseq(xs)
|
||||||
@ -35,11 +35,13 @@ function batchseq(xs)
|
|||||||
Batch{Seq{T,S},B}(xs)
|
Batch{Seq{T,S},B}(xs)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
TensorFlow.get_tensors(x::Tuple) = TensorFlow.get_tensors(collect(x))
|
||||||
|
|
||||||
function (m::SeqModel)(x::BatchSeq)
|
function (m::SeqModel)(x::BatchSeq)
|
||||||
if isempty(m.state) || length(first(m.state)) ≠ length(x)
|
if isempty(m.state) || length(first(m.state)) ≠ length(x)
|
||||||
m.state = map(batchone, m.m.model.states)
|
m.state = m.m.model.state
|
||||||
end
|
end
|
||||||
output = m.m(m.state..., x)
|
output = runmodel(m.m, m.state..., x)
|
||||||
m.state, output = output[1:end-1], output[end]
|
m.state, output = output[1:end-1], output[end]
|
||||||
return batchseq(rawbatch(output))
|
return batchseq(rawbatch(output))
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user