store stacks as part of model
This commit is contained in:
parent
0e08f175bc
commit
1b5b28897c
@ -64,7 +64,7 @@ function tograph(model, args...)
|
|||||||
ctx = Context(interpline(interplambda(interptuple(interpmap(interp)))),
|
ctx = Context(interpline(interplambda(interptuple(interpmap(interp)))),
|
||||||
params = ObjectIdDict(), stacks = Dict())
|
params = ObjectIdDict(), stacks = Dict())
|
||||||
out = interp(ctx, model, map(constant, args)...)
|
out = interp(ctx, model, map(constant, args)...)
|
||||||
return ctx[:params], out
|
return ctx[:params], ctx[:stacks], out
|
||||||
end
|
end
|
||||||
|
|
||||||
TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[2]
|
# Build the TensorFlow graph for `m` applied to `args` and return its output.
# `tograph` returns `(params, stacks, output)`, so the output is the third
# element; indexing `[2]` would hand back the `stacks` Dict instead.
TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[3]
|
||||||
|
@ -2,14 +2,15 @@ type Model
|
|||||||
model::Any
|
model::Any
|
||||||
session::Session
|
session::Session
|
||||||
params::Dict{Flux.Param,Tensor}
|
params::Dict{Flux.Param,Tensor}
|
||||||
|
stacks::Dict
|
||||||
inputs::Vector{Tensor}
|
inputs::Vector{Tensor}
|
||||||
output::Any
|
output::Any
|
||||||
end
|
end
|
||||||
|
|
||||||
function makesession(model, inputs; session = Session(Graph()))
|
function makesession(model, inputs; session = Session(Graph()))
|
||||||
params, output = tograph(model, inputs...)
|
params, stacks, output = tograph(model, inputs...)
|
||||||
run(session, initialize_all_variables())
|
run(session, initialize_all_variables())
|
||||||
Model(model, session, params, inputs, output)
|
Model(model, session, params, stacks, inputs, output)
|
||||||
end
|
end
|
||||||
|
|
||||||
function makesession(model, n::Integer; session = Session(Graph()))
|
function makesession(model, n::Integer; session = Session(Graph()))
|
||||||
|
@ -11,22 +11,22 @@ function makesession(model::Flux.Unrolled)
|
|||||||
sess = Session(Graph())
|
sess = Session(Graph())
|
||||||
input = placeholder(Float32)
|
input = placeholder(Float32)
|
||||||
inputs = TensorFlow.unpack(input, num = model.steps, axis = 1)
|
inputs = TensorFlow.unpack(input, num = model.steps, axis = 1)
|
||||||
params, outputs, instates, outstates = [], [], [], []
|
params, stacks, outputs, instates, outstates = [], [], [], [], []
|
||||||
if model.stateful
|
if model.stateful
|
||||||
instates = [placeholder(Float32) for _ in model.state]
|
instates = [placeholder(Float32) for _ in model.state]
|
||||||
params, (outstates, outputs) = tograph(model, cgroup(instates...), cgroup(inputs...))
|
params, stacks, (outstates, outputs) = tograph(model, cgroup(instates...), cgroup(inputs...))
|
||||||
else
|
else
|
||||||
params, outputs = tograph(model, cgroup(inputs...))
|
params, stacks, outputs = tograph(model, cgroup(inputs...))
|
||||||
end
|
end
|
||||||
output = TensorFlow.pack(outputs, axis = 1)
|
output = TensorFlow.pack(outputs, axis = 1)
|
||||||
run(sess, initialize_all_variables())
|
run(sess, initialize_all_variables())
|
||||||
sess, params, (instates, input), (outstates, output)
|
sess, params, stacks, (instates, input), (outstates, output)
|
||||||
end
|
end
|
||||||
|
|
||||||
function tf(model::Flux.Unrolled)
|
function tf(model::Flux.Unrolled)
|
||||||
sess, params, (instates, input), (outstates, output) = makesession(model)
|
sess, params, stacks, (instates, input), (outstates, output) = makesession(model)
|
||||||
SeqModel(
|
SeqModel(
|
||||||
Model(model, sess, params,
|
Model(model, sess, params, stacks,
|
||||||
[instates..., input], [outstates..., output]),
|
[instates..., input], [outstates..., output]),
|
||||||
model.state)
|
model.state)
|
||||||
end
|
end
|
||||||
@ -60,7 +60,7 @@ function Flux.train!(m::SeqModel, Xs, Ys; epoch = 1, η = 0.1,
|
|||||||
opt = () -> TensorFlow.train.GradientDescentOptimizer(η))
|
opt = () -> TensorFlow.train.GradientDescentOptimizer(η))
|
||||||
batchlen, seqlen = length(first(Xs)), length(first(Xs)[1])
|
batchlen, seqlen = length(first(Xs)), length(first(Xs)[1])
|
||||||
state = batchone.(m.m.model.state)
|
state = batchone.(m.m.model.state)
|
||||||
sess, params, (instates, input), (outstates, output) = makesession(m.m.model)
|
sess, params, stacks, (instates, input), (outstates, output) = makesession(m.m.model)
|
||||||
Y = placeholder(Float32)
|
Y = placeholder(Float32)
|
||||||
Loss = loss(Y, output)/batchlen/seqlen
|
Loss = loss(Y, output)/batchlen/seqlen
|
||||||
minimize_op = TensorFlow.train.minimize(opt(), Loss)
|
minimize_op = TensorFlow.train.minimize(opt(), Loss)
|
||||||
|
Loading…
Reference in New Issue
Block a user