store stacks as part of model

This commit is contained in:
Mike J Innes 2016-12-20 17:32:33 +00:00
parent 0e08f175bc
commit 1b5b28897c
3 changed files with 11 additions and 10 deletions

View File

@ -64,7 +64,7 @@ function tograph(model, args...)
ctx = Context(interpline(interplambda(interptuple(interpmap(interp)))),
params = ObjectIdDict(), stacks = Dict())
out = interp(ctx, model, map(constant, args)...)
return ctx[:params], out
return ctx[:params], ctx[:stacks], out
end
TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[2]

View File

@ -2,14 +2,15 @@ type Model
model::Any
session::Session
params::Dict{Flux.Param,Tensor}
stacks::Dict
inputs::Vector{Tensor}
output::Any
end
function makesession(model, inputs; session = Session(Graph()))
params, output = tograph(model, inputs...)
params, stacks, output = tograph(model, inputs...)
run(session, initialize_all_variables())
Model(model, session, params, inputs, output)
Model(model, session, params, stacks, inputs, output)
end
function makesession(model, n::Integer; session = Session(Graph()))

View File

@ -11,22 +11,22 @@ function makesession(model::Flux.Unrolled)
sess = Session(Graph())
input = placeholder(Float32)
inputs = TensorFlow.unpack(input, num = model.steps, axis = 1)
params, outputs, instates, outstates = [], [], [], []
params, stacks, outputs, instates, outstates = [], [], [], [], []
if model.stateful
instates = [placeholder(Float32) for _ in model.state]
params, (outstates, outputs) = tograph(model, cgroup(instates...), cgroup(inputs...))
params, stacks, (outstates, outputs) = tograph(model, cgroup(instates...), cgroup(inputs...))
else
params, outputs = tograph(model, cgroup(inputs...))
params, stacks, outputs = tograph(model, cgroup(inputs...))
end
output = TensorFlow.pack(outputs, axis = 1)
run(sess, initialize_all_variables())
sess, params, (instates, input), (outstates, output)
sess, params, stacks, (instates, input), (outstates, output)
end
function tf(model::Flux.Unrolled)
sess, params, (instates, input), (outstates, output) = makesession(model)
sess, params, stacks, (instates, input), (outstates, output) = makesession(model)
SeqModel(
Model(model, sess, params,
Model(model, sess, params, stacks,
[instates..., input], [outstates..., output]),
model.state)
end
@ -60,7 +60,7 @@ function Flux.train!(m::SeqModel, Xs, Ys; epoch = 1, η = 0.1,
opt = () -> TensorFlow.train.GradientDescentOptimizer(η))
batchlen, seqlen = length(first(Xs)), length(first(Xs)[1])
state = batchone.(m.m.model.state)
sess, params, (instates, input), (outstates, output) = makesession(m.m.model)
sess, params, stacks, (instates, input), (outstates, output) = makesession(m.m.model)
Y = placeholder(Float32)
Loss = loss(Y, output)/batchlen/seqlen
minimize_op = TensorFlow.train.minimize(opt(), Loss)