From 2467ca4187b6ff293ab25e871ea093c4e0510535 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 1 May 2017 17:41:42 +0100 Subject: [PATCH] tuple inputs in tensorflow --- src/backend/tensorflow/model.jl | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/backend/tensorflow/model.jl b/src/backend/tensorflow/model.jl index 7ae3bed0..b43ce003 100644 --- a/src/backend/tensorflow/model.jl +++ b/src/backend/tensorflow/model.jl @@ -1,4 +1,4 @@ -using Flux: mapt +using Flux: mapt, collectt, shapecheckt struct Exec session::Session @@ -9,21 +9,20 @@ struct Exec end function makesession(model, inputs; session = Session(Graph())) + inputs = mapt(_ -> placeholder(Float32), inputs) params, stacks, output = tograph(model, inputs...) run(session, global_variables_initializer()) Exec(session, inputs, output, params, stacks) end -function makesession(model, n::Integer; session = Session(Graph())) - makesession(model, [placeholder(Float32) for _ = 1:n], session = session) -end - retuple(xs) = xs retuple(xs::AbstractArray{<:AbstractArray}) = (retuple.(xs)...,) +dictt(xs, ys) = Dict(zip(collectt(xs), collectt(ys))) + function (m::Exec)(args...) - @assert length(args) == length(m.input) - retuple(run(m.session, m.output, Dict(zip(m.input, args)))) + shapecheckt(m.input, args) + retuple(run(m.session, m.output, dictt(m.input, args))) end mutable struct Model @@ -35,8 +34,8 @@ end tf(model) = Model(model) function (m::Model)(args...) - args = mapt(x->convert.(Float32, x),args) - isdefined(m, :graph) || (m.exec = makesession(m.model, length(args))) + args = mapt(x->convert.(Float32, x), args) + isdefined(m, :graph) || (m.exec = makesession(m.model, args)) @tferr m.exec.stacks m.exec(args...) end