use new batching approach in TensorFlow

This commit is contained in:
Mike J Innes 2016-10-25 16:21:17 +01:00
parent 46550e4863
commit a06145a145
2 changed files with 10 additions and 7 deletions

View File

@ -2,6 +2,7 @@ module TF
using ..Flux, Flow, TensorFlow, Juno using ..Flux, Flow, TensorFlow, Juno
import Flux: accuracy import Flux: accuracy
import TensorFlow: RawTensor
import Juno: info import Juno: info
export tf export tf
@ -63,11 +64,17 @@ function tf(model)
Model(sess, [input], g, gradients(g, input)) Model(sess, [input], g, gradients(g, input))
end end
function (m::Model)(args...) batch(x) = Batch((x,))
RawTensor(data::Batch) = RawTensor(rawbatch(data))
function (m::Model)(args::Batch...)
@assert length(args) == length(m.inputs) @assert length(args) == length(m.inputs)
Flux.unbatch(run(m.session, m.graph, Dict(zip(m.inputs, map(batch, args))))) run(m.session, m.graph, Dict(zip(m.inputs, args)))
end end
(m::Model)(args...) = m(map(batch, args)...)
function Flux.back!(m::Model, Δ, args...) function Flux.back!(m::Model, Δ, args...)
@assert length(args) == length(m.inputs) @assert length(args) == length(m.inputs)
# TODO: keyword arguments to `gradients` # TODO: keyword arguments to `gradients`

View File

@ -1,8 +1,4 @@
export batch, Batch export Batch
# TODO: support the Batch type only
batch(x) = reshape(x, (1,size(x)...))
batch(xs...) = vcat(map(batch, xs)...)
immutable Batch{T,S} <: AbstractVector{T} immutable Batch{T,S} <: AbstractVector{T}
data::CatMat{T,S} data::CatMat{T,S}