tuple inputs in tensorflow
This commit is contained in:
parent
357f989de5
commit
2467ca4187
@ -1,4 +1,4 @@
|
||||
using Flux: mapt
|
||||
using Flux: mapt, collectt, shapecheckt
|
||||
|
||||
struct Exec
|
||||
session::Session
|
||||
@ -9,21 +9,20 @@ struct Exec
|
||||
end
|
||||
|
||||
function makesession(model, inputs; session = Session(Graph()))
|
||||
inputs = mapt(_ -> placeholder(Float32), inputs)
|
||||
params, stacks, output = tograph(model, inputs...)
|
||||
run(session, global_variables_initializer())
|
||||
Exec(session, inputs, output, params, stacks)
|
||||
end
|
||||
|
||||
function makesession(model, n::Integer; session = Session(Graph()))
|
||||
makesession(model, [placeholder(Float32) for _ = 1:n], session = session)
|
||||
end
|
||||
|
||||
retuple(xs) = xs
|
||||
retuple(xs::AbstractArray{<:AbstractArray}) = (retuple.(xs)...,)
|
||||
|
||||
dictt(xs, ys) = Dict(zip(collectt(xs), collectt(ys)))
|
||||
|
||||
function (m::Exec)(args...)
|
||||
@assert length(args) == length(m.input)
|
||||
retuple(run(m.session, m.output, Dict(zip(m.input, args))))
|
||||
shapecheckt(m.input, args)
|
||||
retuple(run(m.session, m.output, dictt(m.input, args)))
|
||||
end
|
||||
|
||||
mutable struct Model
|
||||
@ -35,8 +34,8 @@ end
|
||||
tf(model) = Model(model)
|
||||
|
||||
function (m::Model)(args...)
|
||||
args = mapt(x->convert.(Float32, x),args)
|
||||
isdefined(m, :graph) || (m.exec = makesession(m.model, length(args)))
|
||||
args = mapt(x->convert.(Float32, x), args)
|
||||
isdefined(m, :graph) || (m.exec = makesession(m.model, args))
|
||||
@tferr m.exec.stacks m.exec(args...)
|
||||
end
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user