tuple inputs in tensorflow

This commit is contained in:
Mike J Innes 2017-05-01 17:41:42 +01:00
parent 357f989de5
commit 2467ca4187

View File

@ -1,4 +1,4 @@
using Flux: mapt
using Flux: mapt, collectt, shapecheckt
struct Exec
session::Session
@ -9,21 +9,20 @@ struct Exec
end
function makesession(model, inputs; session = Session(Graph()))
inputs = mapt(_ -> placeholder(Float32), inputs)
params, stacks, output = tograph(model, inputs...)
run(session, global_variables_initializer())
Exec(session, inputs, output, params, stacks)
end
function makesession(model, n::Integer; session = Session(Graph()))
makesession(model, [placeholder(Float32) for _ = 1:n], session = session)
end
retuple(xs) = xs
retuple(xs::AbstractArray{<:AbstractArray}) = (retuple.(xs)...,)
dictt(xs, ys) = Dict(zip(collectt(xs), collectt(ys)))
function (m::Exec)(args...)
@assert length(args) == length(m.input)
retuple(run(m.session, m.output, Dict(zip(m.input, args))))
shapecheckt(m.input, args)
retuple(run(m.session, m.output, dictt(m.input, args)))
end
mutable struct Model
@ -35,8 +34,8 @@ end
tf(model) = Model(model)
function (m::Model)(args...)
args = mapt(x->convert.(Float32, x),args)
isdefined(m, :graph) || (m.exec = makesession(m.model, length(args)))
args = mapt(x->convert.(Float32, x), args)
isdefined(m, :graph) || (m.exec = makesession(m.model, args))
@tferr m.exec.stacks m.exec(args...)
end