batching refactor, nested batches
This commit is contained in:
parent
183c3b0680
commit
1847809e99
@ -18,9 +18,11 @@ include("layers/shape.jl")
|
|||||||
include("layers/chain.jl")
|
include("layers/chain.jl")
|
||||||
include("layers/shims.jl")
|
include("layers/shims.jl")
|
||||||
|
|
||||||
|
include("dims/catmat.jl")
|
||||||
|
include("dims/batching.jl")
|
||||||
|
|
||||||
include("cost.jl")
|
include("cost.jl")
|
||||||
include("activation.jl")
|
include("activation.jl")
|
||||||
include("batching.jl")
|
|
||||||
|
|
||||||
include("backend/backend.jl")
|
include("backend/backend.jl")
|
||||||
|
|
||||||
|
@ -1,36 +0,0 @@
|
|||||||
export batch, Batch
|
|
||||||
|
|
||||||
# TODO: support the Batch type only
|
|
||||||
batch(x) = reshape(x, (1,size(x)...))
|
|
||||||
batch(xs...) = vcat(map(batch, xs)...)
|
|
||||||
|
|
||||||
type Batch{T,T′} <: AbstractVector{T}
|
|
||||||
data::T′
|
|
||||||
end
|
|
||||||
|
|
||||||
Base.size(b::Batch) = (size(b.data, 1),)
|
|
||||||
|
|
||||||
Base.getindex(b::Batch, i)::eltype(b) = slicedim(b.data, 1, i)
|
|
||||||
|
|
||||||
Base.setindex!(b::Batch, v, i) = b[i, :] = v
|
|
||||||
|
|
||||||
function (::Type{Batch{T}}){T}(xs)
|
|
||||||
x = first(xs)
|
|
||||||
batch = similar(x, length(xs), size(x)...)
|
|
||||||
for i = 1:length(xs)
|
|
||||||
@assert size(xs[i]) == size(x)
|
|
||||||
batch[i, :] = xs[i]
|
|
||||||
end
|
|
||||||
return Batch{T,typeof(batch)}(batch)
|
|
||||||
end
|
|
||||||
|
|
||||||
function Batch(xs)
|
|
||||||
xs′ = promote(xs...)
|
|
||||||
Batch{typeof(xs′[1])}(xs′)
|
|
||||||
end
|
|
||||||
|
|
||||||
@render Juno.Inline b::Batch begin
|
|
||||||
Tree(Row(Text("Batch of "), eltype(b),
|
|
||||||
Juno.fade("[$(length(b))]")),
|
|
||||||
Juno.trim(collect(b)))
|
|
||||||
end
|
|
25
src/dims/batching.jl
Normal file
25
src/dims/batching.jl
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
export batch, Batch
|
||||||
|
|
||||||
|
# TODO: support the Batch type only
|
||||||
|
batch(x) = reshape(x, (1,size(x)...))
|
||||||
|
batch(xs...) = vcat(map(batch, xs)...)
|
||||||
|
|
||||||
|
immutable Batch{T,S} <: AbstractVector{T}
|
||||||
|
data::CatMat{T,S}
|
||||||
|
end
|
||||||
|
|
||||||
|
@forward Batch.data size, eltype, getindex, setindex!, rawbatch
|
||||||
|
|
||||||
|
Batch(xs) = Batch(CatMat(xs))
|
||||||
|
|
||||||
|
convert{T,S}(::Type{Batch{T,S}},storage::S) =
|
||||||
|
Batch{T,S}(storage)
|
||||||
|
|
||||||
|
Media.render{T<:Batch}(i::Juno.Inline, b::Type{T}) =
|
||||||
|
render(i, Row(Juno.typ("Batch"), text"{", eltype(T), text"}"))
|
||||||
|
|
||||||
|
@render Juno.Inline b::Batch begin
|
||||||
|
Tree(Row(Text("Batch of "), eltype(b),
|
||||||
|
Juno.fade("[$(length(b))]")),
|
||||||
|
Juno.trim(collect(b)))
|
||||||
|
end
|
50
src/dims/catmat.jl
Normal file
50
src/dims/catmat.jl
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
import Base: eltype, size, getindex, setindex!, convert
|
||||||
|
|
||||||
|
export CatMat
|
||||||
|
|
||||||
|
immutable CatMat{T,S} <: AbstractVector{T}
|
||||||
|
data::S
|
||||||
|
end
|
||||||
|
|
||||||
|
eltype{T}(::CatMat{T}) = T
|
||||||
|
|
||||||
|
size(b::CatMat) = (size(b.data, 1),)
|
||||||
|
|
||||||
|
getindex(b::CatMat, i)::eltype(b) = slicedim(b.data, 1, i)
|
||||||
|
|
||||||
|
setindex!(b::CatMat, v, i) = b[i, :] = v
|
||||||
|
|
||||||
|
allequal(xs) = all(x -> x == first(xs), xs)
|
||||||
|
|
||||||
|
function (::Type{CatMat{T,S}}){T,S}(xs, storage::S)
|
||||||
|
@assert @>> xs map(size) allequal
|
||||||
|
@assert size(storage) == (length(xs), size(first(xs))...)
|
||||||
|
for i = 1:length(xs)
|
||||||
|
storage[i, :] = xs[i]
|
||||||
|
end
|
||||||
|
return CatMat{T,S}(storage)
|
||||||
|
end
|
||||||
|
|
||||||
|
function (::Type{CatMat{T}}){T}(xs)
|
||||||
|
xs′ = map(rawbatch, xs)
|
||||||
|
storage = similar(first(xs′), (length(xs′), size(first(xs′))...))
|
||||||
|
CatMat{T,typeof(storage)}(xs′, storage)
|
||||||
|
end
|
||||||
|
|
||||||
|
function CatMat(xs)
|
||||||
|
xs = promote(xs...)
|
||||||
|
CatMat{eltype(xs)}(xs)
|
||||||
|
end
|
||||||
|
|
||||||
|
@render Juno.Inline b::CatMat begin
|
||||||
|
Tree(Row(Text("CatMat of "), eltype(b),
|
||||||
|
Juno.fade("[$(length(b))]")),
|
||||||
|
Juno.trim(collect(b)))
|
||||||
|
end
|
||||||
|
|
||||||
|
rawbatch(xs) = xs
|
||||||
|
|
||||||
|
rawbatch(xs::CatMat) = xs.data
|
||||||
|
|
||||||
|
convert{T,S}(::Type{CatMat{T,S}},storage::S) =
|
||||||
|
CatMat{T,S}(storage)
|
Loading…
Reference in New Issue
Block a user