makesession refactor
This commit is contained in:
parent
22568452f1
commit
0e08f175bc
@ -70,17 +70,3 @@ end
|
|||||||
TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[2]
|
TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[2]
|
||||||
|
|
||||||
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
|
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
|
||||||
|
|
||||||
function makesession(model, n)
|
|
||||||
sess = Session(Graph())
|
|
||||||
inputs = [placeholder(Float32) for _ = 1:n]
|
|
||||||
params, output = tograph(model, inputs...)
|
|
||||||
run(sess, initialize_all_variables())
|
|
||||||
sess, params, inputs, output
|
|
||||||
end
|
|
||||||
|
|
||||||
function storeparams!(sess, params)
|
|
||||||
for (p, t) in params
|
|
||||||
p.x = run(sess, t)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
@ -6,14 +6,28 @@ type Model
|
|||||||
output::Any
|
output::Any
|
||||||
end
|
end
|
||||||
|
|
||||||
ismultioutput(m::Model) = !isa(m.output, Tensor)
|
function makesession(model, inputs; session = Session(Graph()))
|
||||||
|
params, output = tograph(model, inputs...)
|
||||||
function tf(model)
|
run(session, initialize_all_variables())
|
||||||
sess, params, input, output = makesession(model, 1)
|
Model(model, session, params, inputs, output)
|
||||||
Model(model, sess, params,
|
|
||||||
input, output)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function makesession(model, n::Integer; session = Session(Graph()))
|
||||||
|
makesession(model, [placeholder(Float32) for _ = 1:n], session = session)
|
||||||
|
end
|
||||||
|
|
||||||
|
tf(model) = makesession(model, 1)
|
||||||
|
|
||||||
|
function storeparams!(sess, params)
|
||||||
|
for (p, t) in params
|
||||||
|
p.x = run(sess, t)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
storeparams!(m::Model) = storeparams!(m.session, m.params)
|
||||||
|
|
||||||
|
ismultioutput(m::Model) = !isa(m.output, Tensor)
|
||||||
|
|
||||||
function batch(xs)
|
function batch(xs)
|
||||||
dims = ndims(xs)-1
|
dims = ndims(xs)-1
|
||||||
T = Array{eltype(xs),dims}
|
T = Array{eltype(xs),dims}
|
||||||
|
Loading…
Reference in New Issue
Block a user