store stacks as part of model

Mike J Innes 2016-12-20 17:32:33 +00:00
parent 0e08f175bc
commit 1b5b28897c
3 changed files with 11 additions and 10 deletions

View File

@@ -64,7 +64,7 @@ function tograph(model, args...)
   ctx = Context(interpline(interplambda(interptuple(interpmap(interp)))),
                 params = ObjectIdDict(), stacks = Dict())
   out = interp(ctx, model, map(constant, args)...)
-  return ctx[:params], out
+  return ctx[:params], ctx[:stacks], out
 end
 
 TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[2]
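For orientation, a minimal caller sketch (hypothetical, not part of this commit) of the three-value contract tograph has after this change; the description of stacks is an assumption based on the interpline wrapper and stacks = Dict() setup above:

    # Hypothetical sketch; `m` and `x` stand in for a Flux model and an input.
    params, stacks, out = tograph(m, x)
    # params :: ObjectIdDict mapping Flux.Param objects to TensorFlow tensors
    # stacks :: Dict of recorded stack/line information from tracing (assumed purpose)
    # out    :: the traced TensorFlow output tensor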

View File

@@ -2,14 +2,15 @@ type Model
   model::Any
   session::Session
   params::Dict{Flux.Param,Tensor}
+  stacks::Dict
   inputs::Vector{Tensor}
   output::Any
 end
 
 function makesession(model, inputs; session = Session(Graph()))
-  params, output = tograph(model, inputs...)
+  params, stacks, output = tograph(model, inputs...)
   run(session, initialize_all_variables())
-  Model(model, session, params, inputs, output)
+  Model(model, session, params, stacks, inputs, output)
 end
 
 function makesession(model, n::Integer; session = Session(Graph()))

View File

@@ -11,22 +11,22 @@ function makesession(model::Flux.Unrolled)
   sess = Session(Graph())
   input = placeholder(Float32)
   inputs = TensorFlow.unpack(input, num = model.steps, axis = 1)
-  params, outputs, instates, outstates = [], [], [], []
+  params, stacks, outputs, instates, outstates = [], [], [], [], []
   if model.stateful
     instates = [placeholder(Float32) for _ in model.state]
-    params, (outstates, outputs) = tograph(model, cgroup(instates...), cgroup(inputs...))
+    params, stacks, (outstates, outputs) = tograph(model, cgroup(instates...), cgroup(inputs...))
   else
-    params, outputs = tograph(model, cgroup(inputs...))
+    params, stacks, outputs = tograph(model, cgroup(inputs...))
   end
   output = TensorFlow.pack(outputs, axis = 1)
   run(sess, initialize_all_variables())
-  sess, params, (instates, input), (outstates, output)
+  sess, params, stacks, (instates, input), (outstates, output)
 end
 
 function tf(model::Flux.Unrolled)
-  sess, params, (instates, input), (outstates, output) = makesession(model)
+  sess, params, stacks, (instates, input), (outstates, output) = makesession(model)
   SeqModel(
-    Model(model, sess, params,
+    Model(model, sess, params, stacks,
           [instates..., input], [outstates..., output]),
     model.state)
 end
@@ -60,7 +60,7 @@ function Flux.train!(m::SeqModel, Xs, Ys; epoch = 1, η = 0.1,
                      opt = () -> TensorFlow.train.GradientDescentOptimizer(η))
   batchlen, seqlen = length(first(Xs)), length(first(Xs)[1])
   state = batchone.(m.m.model.state)
-  sess, params, (instates, input), (outstates, output) = makesession(m.m.model)
+  sess, params, stacks, (instates, input), (outstates, output) = makesession(m.m.model)
   Y = placeholder(Float32)
   Loss = loss(Y, output)/batchlen/seqlen
   minimize_op = TensorFlow.train.minimize(opt(), Loss)