use new batching approach in TensorFlow

This commit is contained in:
Mike J Innes 2016-10-25 16:21:17 +01:00
parent 46550e4863
commit a06145a145
2 changed files with 10 additions and 7 deletions

View File

@ -2,6 +2,7 @@ module TF
using ..Flux, Flow, TensorFlow, Juno
import Flux: accuracy
import TensorFlow: RawTensor
import Juno: info
export tf
@ -63,11 +64,17 @@ function tf(model)
Model(sess, [input], g, gradients(g, input))
end
function (m::Model)(args...)
batch(x) = Batch((x,))
RawTensor(data::Batch) = RawTensor(rawbatch(data))
function (m::Model)(args::Batch...)
@assert length(args) == length(m.inputs)
Flux.unbatch(run(m.session, m.graph, Dict(zip(m.inputs, map(batch, args)))))
run(m.session, m.graph, Dict(zip(m.inputs, args)))
end
(m::Model)(args...) = m(map(batch, args)...)
function Flux.back!(m::Model, Δ, args...)
@assert length(args) == length(m.inputs)
# TODO: keyword arguments to `gradients`

View File

@ -1,8 +1,4 @@
export batch, Batch
# TODO: support the Batch type only
batch(x) = reshape(x, (1,size(x)...))
batch(xs...) = vcat(map(batch, xs)...)
export Batch
immutable Batch{T,S} <: AbstractVector{T}
data::CatMat{T,S}