move batching logic
This commit is contained in:
parent
69551caadb
commit
c9f9665e4e
@ -19,6 +19,7 @@ include("layers/shims.jl")
|
|||||||
|
|
||||||
include("cost.jl")
|
include("cost.jl")
|
||||||
include("activation.jl")
|
include("activation.jl")
|
||||||
|
include("batching.jl")
|
||||||
|
|
||||||
include("backend/backend.jl")
|
include("backend/backend.jl")
|
||||||
|
|
||||||
|
@ -48,13 +48,6 @@ graph(p::MaxPool, x) =
|
|||||||
|
|
||||||
TensorFlow.Tensor(m::Flux.Model, args...) = graph(m, args...)
|
TensorFlow.Tensor(m::Flux.Model, args...) = graph(m, args...)
|
||||||
|
|
||||||
# Treat the first dimension as the batch index
|
|
||||||
# TODO: custom data type for this
|
|
||||||
batch(x) = reshape(x, (1,size(x)...))
|
|
||||||
batch(xs...) = vcat(map(batch, xs)...)
|
|
||||||
|
|
||||||
unbatch(xs) = reshape(xs, size(xs)[2:end])
|
|
||||||
|
|
||||||
type Model
|
type Model
|
||||||
session::Session
|
session::Session
|
||||||
inputs::Vector{Tensor}
|
inputs::Vector{Tensor}
|
||||||
@ -72,7 +65,7 @@ end
|
|||||||
|
|
||||||
function (m::Model)(args...)
|
function (m::Model)(args...)
|
||||||
@assert length(args) == length(m.inputs)
|
@assert length(args) == length(m.inputs)
|
||||||
unbatch(run(m.session, m.graph, Dict(zip(m.inputs, map(batch, args)))))
|
Flux.unbatch(run(m.session, m.graph, Dict(zip(m.inputs, map(batch, args)))))
|
||||||
end
|
end
|
||||||
|
|
||||||
function Flux.back!(m::Model, Δ, args...)
|
function Flux.back!(m::Model, Δ, args...)
|
||||||
|
8
src/batching.jl
Normal file
8
src/batching.jl
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
export batch
|
||||||
|
|
||||||
|
# Treat the first dimension as the batch index
|
||||||
|
# TODO: custom data type for this
|
||||||
|
batch(x) = reshape(x, (1,size(x)...))
|
||||||
|
batch(xs...) = vcat(map(batch, xs)...)
|
||||||
|
|
||||||
|
unbatch(xs) = reshape(xs, size(xs)[2:end])
|
Loading…
Reference in New Issue
Block a user