From 1b5b28897c5884128b574cdcc9cbb4fad8c53c01 Mon Sep 17 00:00:00 2001
From: Mike J Innes
Date: Tue, 20 Dec 2016 17:32:33 +0000
Subject: [PATCH] store stacks as part of model

---
 src/backend/tensorflow/graph.jl     |  2 +-
 src/backend/tensorflow/model.jl     |  5 +++--
 src/backend/tensorflow/recurrent.jl | 14 +++++++-------
 3 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/src/backend/tensorflow/graph.jl b/src/backend/tensorflow/graph.jl
index f5657cef..b096ad77 100644
--- a/src/backend/tensorflow/graph.jl
+++ b/src/backend/tensorflow/graph.jl
@@ -64,7 +64,7 @@ function tograph(model, args...)
   ctx = Context(interpline(interplambda(interptuple(interpmap(interp)))),
                 params = ObjectIdDict(), stacks = Dict())
   out = interp(ctx, model, map(constant, args)...)
-  return ctx[:params], out
+  return ctx[:params], ctx[:stacks], out
 end
 
 TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[2]
diff --git a/src/backend/tensorflow/model.jl b/src/backend/tensorflow/model.jl
index a45cce4a..517598a2 100644
--- a/src/backend/tensorflow/model.jl
+++ b/src/backend/tensorflow/model.jl
@@ -2,14 +2,15 @@ type Model
   model::Any
   session::Session
   params::Dict{Flux.Param,Tensor}
+  stacks::Dict
   inputs::Vector{Tensor}
   output::Any
 end
 
 function makesession(model, inputs; session = Session(Graph()))
-  params, output = tograph(model, inputs...)
+  params, stacks, output = tograph(model, inputs...)
   run(session, initialize_all_variables())
-  Model(model, session, params, inputs, output)
+  Model(model, session, params, stacks, inputs, output)
 end
 
 function makesession(model, n::Integer; session = Session(Graph()))
diff --git a/src/backend/tensorflow/recurrent.jl b/src/backend/tensorflow/recurrent.jl
index eebb3d99..18121038 100644
--- a/src/backend/tensorflow/recurrent.jl
+++ b/src/backend/tensorflow/recurrent.jl
@@ -11,22 +11,22 @@ function makesession(model::Flux.Unrolled)
   sess = Session(Graph())
   input = placeholder(Float32)
   inputs = TensorFlow.unpack(input, num = model.steps, axis = 1)
-  params, outputs, instates, outstates = [], [], [], []
+  params, stacks, outputs, instates, outstates = [], [], [], [], []
   if model.stateful
     instates = [placeholder(Float32) for _ in model.state]
-    params, (outstates, outputs) = tograph(model, cgroup(instates...), cgroup(inputs...))
+    params, stacks, (outstates, outputs) = tograph(model, cgroup(instates...), cgroup(inputs...))
   else
-    params, outputs = tograph(model, cgroup(inputs...))
+    params, stacks, outputs = tograph(model, cgroup(inputs...))
   end
   output = TensorFlow.pack(outputs, axis = 1)
   run(sess, initialize_all_variables())
-  sess, params, (instates, input), (outstates, output)
+  sess, params, stacks, (instates, input), (outstates, output)
 end
 
 function tf(model::Flux.Unrolled)
-  sess, params, (instates, input), (outstates, output) = makesession(model)
+  sess, params, stacks, (instates, input), (outstates, output) = makesession(model)
   SeqModel(
-    Model(model, sess, params,
+    Model(model, sess, params, stacks,
           [instates..., input], [outstates..., output]),
     model.state)
 end
@@ -60,7 +60,7 @@ function Flux.train!(m::SeqModel, Xs, Ys; epoch = 1, η = 0.1,
                      opt = () -> TensorFlow.train.GradientDescentOptimizer(η))
   batchlen, seqlen = length(first(Xs)), length(first(Xs)[1])
  state = batchone.(m.m.model.state)
-  sess, params, (instates, input), (outstates, output) = makesession(m.m.model)
+  sess, params, stacks, (instates, input), (outstates, output) = makesession(m.m.model)
   Y = placeholder(Float32)
   Loss = loss(Y, output)/batchlen/seqlen
   minimize_op = TensorFlow.train.minimize(opt(), Loss)
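
After this change, tograph returns the recorded stacks alongside the parameter map, and Model carries them in a new stacks field. A rough sketch of the resulting calling convention, mirroring makesession above; the names mymodel, inputs, and session are hypothetical, and the surrounding Flux/TensorFlow.jl internals of this revision are assumed:

  # Sketch only: tograph now also yields the stacks dict, and Model stores it.
  inputs = [placeholder(Float32)]
  params, stacks, output = tograph(mymodel, inputs...)
  run(session, initialize_all_variables())
  m = Model(mymodel, session, params, stacks, inputs, output)
  # m.stacks presumably lets backend errors be traced back to the Julia code
  # that produced each graph node.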