generic and consistent conversions
This commit is contained in:
parent
f0880f89cc
commit
13f4af2492
@ -3,9 +3,7 @@ module Batches
|
||||
using Juno, Lazy
|
||||
using Juno: Tree, Row
|
||||
|
||||
export Storage, rawbatch,
|
||||
Batch, Batched, batchone, tobatch, rebatch,
|
||||
Seq, BatchSeq, rebatchseq
|
||||
export Batch, Batched, Seq, rawbatch, batchone
|
||||
|
||||
include("catmat.jl")
|
||||
include("batch.jl")
|
||||
|
@ -6,14 +6,6 @@ end
|
||||
|
||||
Batch(xs) = Batch(Storage(xs))
|
||||
|
||||
# TODO: figure out how to express this as a generic convert
|
||||
function rebatch(xs)
|
||||
dims = ndims(xs)-1
|
||||
T = Array{eltype(xs),dims}
|
||||
B = Array{eltype(xs),dims+1}
|
||||
Batch{T,B}(xs)
|
||||
end
|
||||
|
||||
convertel(T::Type, xs::Batch) =
|
||||
eltype(eltype(xs)) isa T ? xs :
|
||||
Batch(map(x->convertel(T, x), xs))
|
||||
@ -31,13 +23,3 @@ struct Seq{T,S} <: BatchLike{T,S}
|
||||
end
|
||||
|
||||
Seq(xs) = Seq(Storage(xs))
|
||||
|
||||
BatchSeq{T<:Seq} = Batch{T}
|
||||
|
||||
function rebatchseq(xs)
|
||||
dims = ndims(xs)-2
|
||||
T = Array{eltype(xs),dims}
|
||||
S = Array{eltype(xs),dims+1}
|
||||
B = Array{eltype(xs),dims+2}
|
||||
Batch{Seq{T,S},B}(xs)
|
||||
end
|
||||
|
@ -1,7 +1,9 @@
|
||||
import Base: eltype, size, getindex, setindex!, convert
|
||||
import Base: eltype, size, getindex, setindex!, convert, typename
|
||||
|
||||
rawbatch(xs) = xs
|
||||
|
||||
# Generic methods
|
||||
|
||||
abstract type BatchLike{T,S} <: AbstractVector{T}
|
||||
end
|
||||
|
||||
@ -21,17 +23,17 @@ function setindex!(b::BatchLike, xs, ::Colon)
|
||||
end
|
||||
end
|
||||
|
||||
typename(b::Type) = b
|
||||
typename(b::Type{<:BatchLike}) =
|
||||
Row(Juno.typ("$(b.name.name)"), text"{", typename(eltype(b)), text"}")
|
||||
typerender(B::Type) = B
|
||||
typerender(B::Type{<:BatchLike}) =
|
||||
Row(Juno.typ("$(typename(B).name)"), text"{", typerender(eltype(B)), text"}")
|
||||
|
||||
@render Juno.Inline b::BatchLike begin
|
||||
Tree(Row(typename(typeof(b)),
|
||||
Tree(Row(typerender(typeof(b)),
|
||||
Juno.fade("[$(length(b))]")),
|
||||
Juno.trim(collect(b)))
|
||||
end
|
||||
|
||||
convert{T,S}(B::Type{<:BatchLike{T,S}},storage::S) = B(storage)
|
||||
# Concrete storage
|
||||
|
||||
struct Storage{T,S} <: BatchLike{T,S}
|
||||
data::S
|
||||
@ -39,7 +41,7 @@ end
|
||||
|
||||
allequal(xs) = all(x -> x == first(xs), xs)
|
||||
|
||||
function (::Type{Storage{T,S}}){T,S}(xs, storage::S)
|
||||
function Storage{T,S}(xs, storage::S) where {T, S}
|
||||
@assert allequal(map(size, xs))
|
||||
@assert size(storage) == (length(xs), size(first(xs))...)
|
||||
for i = 1:length(xs)
|
||||
@ -48,13 +50,28 @@ function (::Type{Storage{T,S}}){T,S}(xs, storage::S)
|
||||
return Storage{T,S}(storage)
|
||||
end
|
||||
|
||||
function (::Type{Storage{T}}){T}(xs)
|
||||
function Storage{T}(xs) where T
|
||||
xs′ = map(rawbatch, xs)
|
||||
storage = similar(first(xs′), (length(xs′), size(first(xs′))...))
|
||||
Storage{T,typeof(storage)}(xs′, storage)
|
||||
end
|
||||
|
||||
function Storage(xs)
|
||||
xs = promote(xs...)
|
||||
Storage{eltype(xs)}(xs)
|
||||
end
|
||||
Storage(xs) = Storage{eltype(xs)}(xs)
|
||||
|
||||
convert{T,S}(B::Type{<:BatchLike{T,S}}, data::S) = B(data)
|
||||
|
||||
# Horrible type hacks follow this point
|
||||
|
||||
deparam(T::Type) = typename(T).wrapper
|
||||
|
||||
dimless(T::Type{<:AbstractArray}) = ndims(T) == 1 ? eltype(T) : deparam(T){eltype(T),ndims(T)-1}
|
||||
|
||||
btype(B::Type{<:BatchLike}, S::Type{<:AbstractArray}) = B{dimless(S),S}
|
||||
btype(B::Type{<:BatchLike{T}} where T, S::Type{<:AbstractArray}) = B{S}
|
||||
btype(B::Type{<:BatchLike{<:BatchLike}}, S::Type{<:AbstractArray}) =
|
||||
deparam(B){btype(eltype(B), dimless(S)),S}
|
||||
|
||||
convert{T<:BatchLike}(::Type{T}, xs::AbstractArray) =
|
||||
convert(btype(T, typeof(xs)), xs)
|
||||
|
||||
convert{T<:BatchLike}(::Type{T}, x::T) = x
|
||||
|
@ -48,7 +48,7 @@ end
|
||||
function Batched(itr, n::Integer)
|
||||
n >= 1 || throw(ArgumentError("batch size must be >= 1"))
|
||||
itr = StatefulIter(itr)
|
||||
buf = rebatch(similar(eltype(itr)(), n, size(peek(itr))...))
|
||||
buf = convert(Batch, similar(eltype(itr)(), n, size(peek(itr))...))
|
||||
Batched(itr, buf)
|
||||
end
|
||||
|
||||
|
@ -31,7 +31,7 @@ end
|
||||
|
||||
runseq(f, xs::Tuple...) = f(xs...)
|
||||
runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2)
|
||||
runseq(f, xs::BatchSeq...) = rebatchseq(runseq(f, rawbatch.(xs)...))
|
||||
runseq(f, xs::Batch{<:Seq}...) = convert(Batch{Seq}, runseq(f, rawbatch.(xs)...))
|
||||
runseq(f, xs) = runseq(f, (xs...,))
|
||||
|
||||
function (m::SeqModel)(x)
|
||||
|
Loading…
Reference in New Issue
Block a user