generic and consistent conversions
This commit is contained in:
parent
f0880f89cc
commit
13f4af2492
@ -3,9 +3,7 @@ module Batches
|
|||||||
using Juno, Lazy
|
using Juno, Lazy
|
||||||
using Juno: Tree, Row
|
using Juno: Tree, Row
|
||||||
|
|
||||||
export Storage, rawbatch,
|
export Batch, Batched, Seq, rawbatch, batchone
|
||||||
Batch, Batched, batchone, tobatch, rebatch,
|
|
||||||
Seq, BatchSeq, rebatchseq
|
|
||||||
|
|
||||||
include("catmat.jl")
|
include("catmat.jl")
|
||||||
include("batch.jl")
|
include("batch.jl")
|
||||||
|
@ -6,14 +6,6 @@ end
|
|||||||
|
|
||||||
Batch(xs) = Batch(Storage(xs))
|
Batch(xs) = Batch(Storage(xs))
|
||||||
|
|
||||||
# TODO: figure out how to express this as a generic convert
|
|
||||||
function rebatch(xs)
|
|
||||||
dims = ndims(xs)-1
|
|
||||||
T = Array{eltype(xs),dims}
|
|
||||||
B = Array{eltype(xs),dims+1}
|
|
||||||
Batch{T,B}(xs)
|
|
||||||
end
|
|
||||||
|
|
||||||
convertel(T::Type, xs::Batch) =
|
convertel(T::Type, xs::Batch) =
|
||||||
eltype(eltype(xs)) isa T ? xs :
|
eltype(eltype(xs)) isa T ? xs :
|
||||||
Batch(map(x->convertel(T, x), xs))
|
Batch(map(x->convertel(T, x), xs))
|
||||||
@ -31,13 +23,3 @@ struct Seq{T,S} <: BatchLike{T,S}
|
|||||||
end
|
end
|
||||||
|
|
||||||
Seq(xs) = Seq(Storage(xs))
|
Seq(xs) = Seq(Storage(xs))
|
||||||
|
|
||||||
BatchSeq{T<:Seq} = Batch{T}
|
|
||||||
|
|
||||||
function rebatchseq(xs)
|
|
||||||
dims = ndims(xs)-2
|
|
||||||
T = Array{eltype(xs),dims}
|
|
||||||
S = Array{eltype(xs),dims+1}
|
|
||||||
B = Array{eltype(xs),dims+2}
|
|
||||||
Batch{Seq{T,S},B}(xs)
|
|
||||||
end
|
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
import Base: eltype, size, getindex, setindex!, convert
|
import Base: eltype, size, getindex, setindex!, convert, typename
|
||||||
|
|
||||||
rawbatch(xs) = xs
|
rawbatch(xs) = xs
|
||||||
|
|
||||||
|
# Generic methods
|
||||||
|
|
||||||
abstract type BatchLike{T,S} <: AbstractVector{T}
|
abstract type BatchLike{T,S} <: AbstractVector{T}
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -21,17 +23,17 @@ function setindex!(b::BatchLike, xs, ::Colon)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
typename(b::Type) = b
|
typerender(B::Type) = B
|
||||||
typename(b::Type{<:BatchLike}) =
|
typerender(B::Type{<:BatchLike}) =
|
||||||
Row(Juno.typ("$(b.name.name)"), text"{", typename(eltype(b)), text"}")
|
Row(Juno.typ("$(typename(B).name)"), text"{", typerender(eltype(B)), text"}")
|
||||||
|
|
||||||
@render Juno.Inline b::BatchLike begin
|
@render Juno.Inline b::BatchLike begin
|
||||||
Tree(Row(typename(typeof(b)),
|
Tree(Row(typerender(typeof(b)),
|
||||||
Juno.fade("[$(length(b))]")),
|
Juno.fade("[$(length(b))]")),
|
||||||
Juno.trim(collect(b)))
|
Juno.trim(collect(b)))
|
||||||
end
|
end
|
||||||
|
|
||||||
convert{T,S}(B::Type{<:BatchLike{T,S}},storage::S) = B(storage)
|
# Concrete storage
|
||||||
|
|
||||||
struct Storage{T,S} <: BatchLike{T,S}
|
struct Storage{T,S} <: BatchLike{T,S}
|
||||||
data::S
|
data::S
|
||||||
@ -39,7 +41,7 @@ end
|
|||||||
|
|
||||||
allequal(xs) = all(x -> x == first(xs), xs)
|
allequal(xs) = all(x -> x == first(xs), xs)
|
||||||
|
|
||||||
function (::Type{Storage{T,S}}){T,S}(xs, storage::S)
|
function Storage{T,S}(xs, storage::S) where {T, S}
|
||||||
@assert allequal(map(size, xs))
|
@assert allequal(map(size, xs))
|
||||||
@assert size(storage) == (length(xs), size(first(xs))...)
|
@assert size(storage) == (length(xs), size(first(xs))...)
|
||||||
for i = 1:length(xs)
|
for i = 1:length(xs)
|
||||||
@ -48,13 +50,28 @@ function (::Type{Storage{T,S}}){T,S}(xs, storage::S)
|
|||||||
return Storage{T,S}(storage)
|
return Storage{T,S}(storage)
|
||||||
end
|
end
|
||||||
|
|
||||||
function (::Type{Storage{T}}){T}(xs)
|
function Storage{T}(xs) where T
|
||||||
xs′ = map(rawbatch, xs)
|
xs′ = map(rawbatch, xs)
|
||||||
storage = similar(first(xs′), (length(xs′), size(first(xs′))...))
|
storage = similar(first(xs′), (length(xs′), size(first(xs′))...))
|
||||||
Storage{T,typeof(storage)}(xs′, storage)
|
Storage{T,typeof(storage)}(xs′, storage)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Storage(xs)
|
Storage(xs) = Storage{eltype(xs)}(xs)
|
||||||
xs = promote(xs...)
|
|
||||||
Storage{eltype(xs)}(xs)
|
convert{T,S}(B::Type{<:BatchLike{T,S}}, data::S) = B(data)
|
||||||
end
|
|
||||||
|
# Horrible type hacks follow this point
|
||||||
|
|
||||||
|
deparam(T::Type) = typename(T).wrapper
|
||||||
|
|
||||||
|
dimless(T::Type{<:AbstractArray}) = ndims(T) == 1 ? eltype(T) : deparam(T){eltype(T),ndims(T)-1}
|
||||||
|
|
||||||
|
btype(B::Type{<:BatchLike}, S::Type{<:AbstractArray}) = B{dimless(S),S}
|
||||||
|
btype(B::Type{<:BatchLike{T}} where T, S::Type{<:AbstractArray}) = B{S}
|
||||||
|
btype(B::Type{<:BatchLike{<:BatchLike}}, S::Type{<:AbstractArray}) =
|
||||||
|
deparam(B){btype(eltype(B), dimless(S)),S}
|
||||||
|
|
||||||
|
convert{T<:BatchLike}(::Type{T}, xs::AbstractArray) =
|
||||||
|
convert(btype(T, typeof(xs)), xs)
|
||||||
|
|
||||||
|
convert{T<:BatchLike}(::Type{T}, x::T) = x
|
||||||
|
@ -48,7 +48,7 @@ end
|
|||||||
function Batched(itr, n::Integer)
|
function Batched(itr, n::Integer)
|
||||||
n >= 1 || throw(ArgumentError("batch size must be >= 1"))
|
n >= 1 || throw(ArgumentError("batch size must be >= 1"))
|
||||||
itr = StatefulIter(itr)
|
itr = StatefulIter(itr)
|
||||||
buf = rebatch(similar(eltype(itr)(), n, size(peek(itr))...))
|
buf = convert(Batch, similar(eltype(itr)(), n, size(peek(itr))...))
|
||||||
Batched(itr, buf)
|
Batched(itr, buf)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ end
|
|||||||
|
|
||||||
runseq(f, xs::Tuple...) = f(xs...)
|
runseq(f, xs::Tuple...) = f(xs...)
|
||||||
runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2)
|
runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2)
|
||||||
runseq(f, xs::BatchSeq...) = rebatchseq(runseq(f, rawbatch.(xs)...))
|
runseq(f, xs::Batch{<:Seq}...) = convert(Batch{Seq}, runseq(f, rawbatch.(xs)...))
|
||||||
runseq(f, xs) = runseq(f, (xs...,))
|
runseq(f, xs) = runseq(f, (xs...,))
|
||||||
|
|
||||||
function (m::SeqModel)(x)
|
function (m::SeqModel)(x)
|
||||||
|
Loading…
Reference in New Issue
Block a user