store stacks as part of model
This commit is contained in:
parent
0e08f175bc
commit
1b5b28897c
@ -64,7 +64,7 @@ function tograph(model, args...)
|
||||
ctx = Context(interpline(interplambda(interptuple(interpmap(interp)))),
|
||||
params = ObjectIdDict(), stacks = Dict())
|
||||
out = interp(ctx, model, map(constant, args)...)
|
||||
return ctx[:params], out
|
||||
return ctx[:params], ctx[:stacks], out
|
||||
end
|
||||
|
||||
TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[2]
|
||||
|
@ -2,14 +2,15 @@ type Model
|
||||
model::Any
|
||||
session::Session
|
||||
params::Dict{Flux.Param,Tensor}
|
||||
stacks::Dict
|
||||
inputs::Vector{Tensor}
|
||||
output::Any
|
||||
end
|
||||
|
||||
function makesession(model, inputs; session = Session(Graph()))
|
||||
params, output = tograph(model, inputs...)
|
||||
params, stacks, output = tograph(model, inputs...)
|
||||
run(session, initialize_all_variables())
|
||||
Model(model, session, params, inputs, output)
|
||||
Model(model, session, params, stacks, inputs, output)
|
||||
end
|
||||
|
||||
function makesession(model, n::Integer; session = Session(Graph()))
|
||||
|
@ -11,22 +11,22 @@ function makesession(model::Flux.Unrolled)
|
||||
sess = Session(Graph())
|
||||
input = placeholder(Float32)
|
||||
inputs = TensorFlow.unpack(input, num = model.steps, axis = 1)
|
||||
params, outputs, instates, outstates = [], [], [], []
|
||||
params, stacks, outputs, instates, outstates = [], [], [], [], []
|
||||
if model.stateful
|
||||
instates = [placeholder(Float32) for _ in model.state]
|
||||
params, (outstates, outputs) = tograph(model, cgroup(instates...), cgroup(inputs...))
|
||||
params, stacks, (outstates, outputs) = tograph(model, cgroup(instates...), cgroup(inputs...))
|
||||
else
|
||||
params, outputs = tograph(model, cgroup(inputs...))
|
||||
params, stacks, outputs = tograph(model, cgroup(inputs...))
|
||||
end
|
||||
output = TensorFlow.pack(outputs, axis = 1)
|
||||
run(sess, initialize_all_variables())
|
||||
sess, params, (instates, input), (outstates, output)
|
||||
sess, params, stacks, (instates, input), (outstates, output)
|
||||
end
|
||||
|
||||
function tf(model::Flux.Unrolled)
|
||||
sess, params, (instates, input), (outstates, output) = makesession(model)
|
||||
sess, params, stacks, (instates, input), (outstates, output) = makesession(model)
|
||||
SeqModel(
|
||||
Model(model, sess, params,
|
||||
Model(model, sess, params, stacks,
|
||||
[instates..., input], [outstates..., output]),
|
||||
model.state)
|
||||
end
|
||||
@ -60,7 +60,7 @@ function Flux.train!(m::SeqModel, Xs, Ys; epoch = 1, η = 0.1,
|
||||
opt = () -> TensorFlow.train.GradientDescentOptimizer(η))
|
||||
batchlen, seqlen = length(first(Xs)), length(first(Xs)[1])
|
||||
state = batchone.(m.m.model.state)
|
||||
sess, params, (instates, input), (outstates, output) = makesession(m.m.model)
|
||||
sess, params, stacks, (instates, input), (outstates, output) = makesession(m.m.model)
|
||||
Y = placeholder(Float32)
|
||||
Loss = loss(Y, output)/batchlen/seqlen
|
||||
minimize_op = TensorFlow.train.minimize(opt(), Loss)
|
||||
|
Loading…
Reference in New Issue
Block a user