makesession refactor
This commit is contained in:
parent
22568452f1
commit
0e08f175bc
@ -70,17 +70,3 @@ end
|
||||
TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[2]
|
||||
|
||||
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
|
||||
|
||||
function makesession(model, n)
|
||||
sess = Session(Graph())
|
||||
inputs = [placeholder(Float32) for _ = 1:n]
|
||||
params, output = tograph(model, inputs...)
|
||||
run(sess, initialize_all_variables())
|
||||
sess, params, inputs, output
|
||||
end
|
||||
|
||||
function storeparams!(sess, params)
|
||||
for (p, t) in params
|
||||
p.x = run(sess, t)
|
||||
end
|
||||
end
|
||||
|
@ -6,14 +6,28 @@ type Model
|
||||
output::Any
|
||||
end
|
||||
|
||||
ismultioutput(m::Model) = !isa(m.output, Tensor)
|
||||
|
||||
function tf(model)
|
||||
sess, params, input, output = makesession(model, 1)
|
||||
Model(model, sess, params,
|
||||
input, output)
|
||||
function makesession(model, inputs; session = Session(Graph()))
|
||||
params, output = tograph(model, inputs...)
|
||||
run(session, initialize_all_variables())
|
||||
Model(model, session, params, inputs, output)
|
||||
end
|
||||
|
||||
function makesession(model, n::Integer; session = Session(Graph()))
|
||||
makesession(model, [placeholder(Float32) for _ = 1:n], session = session)
|
||||
end
|
||||
|
||||
tf(model) = makesession(model, 1)
|
||||
|
||||
function storeparams!(sess, params)
|
||||
for (p, t) in params
|
||||
p.x = run(sess, t)
|
||||
end
|
||||
end
|
||||
|
||||
storeparams!(m::Model) = storeparams!(m.session, m.params)
|
||||
|
||||
ismultioutput(m::Model) = !isa(m.output, Tensor)
|
||||
|
||||
function batch(xs)
|
||||
dims = ndims(xs)-1
|
||||
T = Array{eltype(xs),dims}
|
||||
|
Loading…
Reference in New Issue
Block a user