move batching logic

Mike J Innes 2016-10-12 17:07:22 +01:00
parent 69551caadb
commit c9f9665e4e
3 changed files with 10 additions and 8 deletions


@@ -19,6 +19,7 @@ include("layers/shims.jl")
 include("cost.jl")
 include("activation.jl")
+include("batching.jl")
 include("backend/backend.jl")


@@ -48,13 +48,6 @@ graph(p::MaxPool, x) =
 TensorFlow.Tensor(m::Flux.Model, args...) = graph(m, args...)
-# Treat the first dimension as the batch index
-# TODO: custom data type for this
-batch(x) = reshape(x, (1,size(x)...))
-batch(xs...) = vcat(map(batch, xs)...)
-unbatch(xs) = reshape(xs, size(xs)[2:end])
 type Model
   session::Session
   inputs::Vector{Tensor}
@@ -72,7 +65,7 @@ end
 function (m::Model)(args...)
   @assert length(args) == length(m.inputs)
-  unbatch(run(m.session, m.graph, Dict(zip(m.inputs, map(batch, args)))))
+  Flux.unbatch(run(m.session, m.graph, Dict(zip(m.inputs, map(batch, args)))))
 end
 function Flux.back!(m::Model, Δ, args...)

src/batching.jl (new file, 8 additions)

@@ -0,0 +1,8 @@
+export batch
+# Treat the first dimension as the batch index
+# TODO: custom data type for this
+batch(x) = reshape(x, (1,size(x)...))
+batch(xs...) = vcat(map(batch, xs)...)
+unbatch(xs) = reshape(xs, size(xs)[2:end])
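
For reference, a small sketch of the batching convention these helpers implement (not part of the commit; it re-states the definitions above and runs them on plain Julia arrays):

# A sample gets a leading batch dimension of size 1; several samples are
# stacked along that dimension; unbatch drops the singleton dimension again.
batch(x) = reshape(x, (1, size(x)...))
batch(xs...) = vcat(map(batch, xs)...)
unbatch(xs) = reshape(xs, size(xs)[2:end])

x = [1.0, 2.0, 3.0]              # one sample: a length-3 vector
b = batch(x)                     # 1×3 matrix, row 1 is the sample
b2 = batch(x, [4.0, 5.0, 6.0])   # 2×3 matrix: two samples, one per row
unbatch(b) == x                  # true: unbatch undoes batch for a single sample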