makesession refactor

This commit is contained in:
Mike J Innes 2016-12-20 17:18:40 +00:00
parent 22568452f1
commit 0e08f175bc
2 changed files with 20 additions and 20 deletions

View File

@ -70,17 +70,3 @@ end
TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[2]
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
function makesession(model, n)
sess = Session(Graph())
inputs = [placeholder(Float32) for _ = 1:n]
params, output = tograph(model, inputs...)
run(sess, initialize_all_variables())
sess, params, inputs, output
end
function storeparams!(sess, params)
for (p, t) in params
p.x = run(sess, t)
end
end

View File

@ -6,14 +6,28 @@ type Model
output::Any
end
ismultioutput(m::Model) = !isa(m.output, Tensor)
function tf(model)
sess, params, input, output = makesession(model, 1)
Model(model, sess, params,
input, output)
function makesession(model, inputs; session = Session(Graph()))
params, output = tograph(model, inputs...)
run(session, initialize_all_variables())
Model(model, session, params, inputs, output)
end
function makesession(model, n::Integer; session = Session(Graph()))
makesession(model, [placeholder(Float32) for _ = 1:n], session = session)
end
tf(model) = makesession(model, 1)
function storeparams!(sess, params)
for (p, t) in params
p.x = run(sess, t)
end
end
storeparams!(m::Model) = storeparams!(m.session, m.params)
ismultioutput(m::Model) = !isa(m.output, Tensor)
function batch(xs)
dims = ndims(xs)-1
T = Array{eltype(xs),dims}