From 0e08f175bc32bb500647846c18eb1eb34e54847e Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 20 Dec 2016 17:18:40 +0000 Subject: [PATCH] makesession refactor --- src/backend/tensorflow/graph.jl | 14 -------------- src/backend/tensorflow/model.jl | 26 ++++++++++++++++++++------ 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/backend/tensorflow/graph.jl b/src/backend/tensorflow/graph.jl index e50efca4..f5657cef 100644 --- a/src/backend/tensorflow/graph.jl +++ b/src/backend/tensorflow/graph.jl @@ -70,17 +70,3 @@ end TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[2] RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data)) - -function makesession(model, n) - sess = Session(Graph()) - inputs = [placeholder(Float32) for _ = 1:n] - params, output = tograph(model, inputs...) - run(sess, initialize_all_variables()) - sess, params, inputs, output -end - -function storeparams!(sess, params) - for (p, t) in params - p.x = run(sess, t) - end -end diff --git a/src/backend/tensorflow/model.jl b/src/backend/tensorflow/model.jl index a0a3cebd..a45cce4a 100644 --- a/src/backend/tensorflow/model.jl +++ b/src/backend/tensorflow/model.jl @@ -6,14 +6,28 @@ type Model output::Any end -ismultioutput(m::Model) = !isa(m.output, Tensor) - -function tf(model) - sess, params, input, output = makesession(model, 1) - Model(model, sess, params, - input, output) +function makesession(model, inputs; session = Session(Graph())) + params, output = tograph(model, inputs...) + run(session, initialize_all_variables()) + Model(model, session, params, inputs, output) end +function makesession(model, n::Integer; session = Session(Graph())) + makesession(model, [placeholder(Float32) for _ = 1:n], session = session) +end + +tf(model) = makesession(model, 1) + +function storeparams!(sess, params) + for (p, t) in params + p.x = run(sess, t) + end +end + +storeparams!(m::Model) = storeparams!(m.session, m.params) + +ismultioutput(m::Model) = !isa(m.output, Tensor) + function batch(xs) dims = ndims(xs)-1 T = Array{eltype(xs),dims}