generic and consistent conversions

This commit is contained in:
Mike J Innes 2017-06-05 22:49:31 +01:00
parent f0880f89cc
commit 13f4af2492
5 changed files with 32 additions and 35 deletions

View File

@ -3,9 +3,7 @@ module Batches
using Juno, Lazy using Juno, Lazy
using Juno: Tree, Row using Juno: Tree, Row
export Storage, rawbatch, export Batch, Batched, Seq, rawbatch, batchone
Batch, Batched, batchone, tobatch, rebatch,
Seq, BatchSeq, rebatchseq
include("catmat.jl") include("catmat.jl")
include("batch.jl") include("batch.jl")

View File

@ -6,14 +6,6 @@ end
Batch(xs) = Batch(Storage(xs)) Batch(xs) = Batch(Storage(xs))
# TODO: figure out how to express this as a generic convert
function rebatch(xs)
dims = ndims(xs)-1
T = Array{eltype(xs),dims}
B = Array{eltype(xs),dims+1}
Batch{T,B}(xs)
end
convertel(T::Type, xs::Batch) = convertel(T::Type, xs::Batch) =
eltype(eltype(xs)) isa T ? xs : eltype(eltype(xs)) isa T ? xs :
Batch(map(x->convertel(T, x), xs)) Batch(map(x->convertel(T, x), xs))
@ -31,13 +23,3 @@ struct Seq{T,S} <: BatchLike{T,S}
end end
Seq(xs) = Seq(Storage(xs)) Seq(xs) = Seq(Storage(xs))
BatchSeq{T<:Seq} = Batch{T}
function rebatchseq(xs)
dims = ndims(xs)-2
T = Array{eltype(xs),dims}
S = Array{eltype(xs),dims+1}
B = Array{eltype(xs),dims+2}
Batch{Seq{T,S},B}(xs)
end

View File

@ -1,7 +1,9 @@
import Base: eltype, size, getindex, setindex!, convert import Base: eltype, size, getindex, setindex!, convert, typename
rawbatch(xs) = xs rawbatch(xs) = xs
# Generic methods
abstract type BatchLike{T,S} <: AbstractVector{T} abstract type BatchLike{T,S} <: AbstractVector{T}
end end
@ -21,17 +23,17 @@ function setindex!(b::BatchLike, xs, ::Colon)
end end
end end
typename(b::Type) = b typerender(B::Type) = B
typename(b::Type{<:BatchLike}) = typerender(B::Type{<:BatchLike}) =
Row(Juno.typ("$(b.name.name)"), text"{", typename(eltype(b)), text"}") Row(Juno.typ("$(typename(B).name)"), text"{", typerender(eltype(B)), text"}")
@render Juno.Inline b::BatchLike begin @render Juno.Inline b::BatchLike begin
Tree(Row(typename(typeof(b)), Tree(Row(typerender(typeof(b)),
Juno.fade("[$(length(b))]")), Juno.fade("[$(length(b))]")),
Juno.trim(collect(b))) Juno.trim(collect(b)))
end end
convert{T,S}(B::Type{<:BatchLike{T,S}},storage::S) = B(storage) # Concrete storage
struct Storage{T,S} <: BatchLike{T,S} struct Storage{T,S} <: BatchLike{T,S}
data::S data::S
@ -39,7 +41,7 @@ end
allequal(xs) = all(x -> x == first(xs), xs) allequal(xs) = all(x -> x == first(xs), xs)
function (::Type{Storage{T,S}}){T,S}(xs, storage::S) function Storage{T,S}(xs, storage::S) where {T, S}
@assert allequal(map(size, xs)) @assert allequal(map(size, xs))
@assert size(storage) == (length(xs), size(first(xs))...) @assert size(storage) == (length(xs), size(first(xs))...)
for i = 1:length(xs) for i = 1:length(xs)
@ -48,13 +50,28 @@ function (::Type{Storage{T,S}}){T,S}(xs, storage::S)
return Storage{T,S}(storage) return Storage{T,S}(storage)
end end
function (::Type{Storage{T}}){T}(xs) function Storage{T}(xs) where T
xs = map(rawbatch, xs) xs = map(rawbatch, xs)
storage = similar(first(xs), (length(xs), size(first(xs))...)) storage = similar(first(xs), (length(xs), size(first(xs))...))
Storage{T,typeof(storage)}(xs, storage) Storage{T,typeof(storage)}(xs, storage)
end end
function Storage(xs) Storage(xs) = Storage{eltype(xs)}(xs)
xs = promote(xs...)
Storage{eltype(xs)}(xs) convert{T,S}(B::Type{<:BatchLike{T,S}}, data::S) = B(data)
end
# Horrible type hacks follow this point
deparam(T::Type) = typename(T).wrapper
dimless(T::Type{<:AbstractArray}) = ndims(T) == 1 ? eltype(T) : deparam(T){eltype(T),ndims(T)-1}
btype(B::Type{<:BatchLike}, S::Type{<:AbstractArray}) = B{dimless(S),S}
btype(B::Type{<:BatchLike{T}} where T, S::Type{<:AbstractArray}) = B{S}
btype(B::Type{<:BatchLike{<:BatchLike}}, S::Type{<:AbstractArray}) =
deparam(B){btype(eltype(B), dimless(S)),S}
convert{T<:BatchLike}(::Type{T}, xs::AbstractArray) =
convert(btype(T, typeof(xs)), xs)
convert{T<:BatchLike}(::Type{T}, x::T) = x

View File

@ -48,7 +48,7 @@ end
function Batched(itr, n::Integer) function Batched(itr, n::Integer)
n >= 1 || throw(ArgumentError("batch size must be >= 1")) n >= 1 || throw(ArgumentError("batch size must be >= 1"))
itr = StatefulIter(itr) itr = StatefulIter(itr)
buf = rebatch(similar(eltype(itr)(), n, size(peek(itr))...)) buf = convert(Batch, similar(eltype(itr)(), n, size(peek(itr))...))
Batched(itr, buf) Batched(itr, buf)
end end

View File

@ -31,7 +31,7 @@ end
runseq(f, xs::Tuple...) = f(xs...) runseq(f, xs::Tuple...) = f(xs...)
runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2) runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2)
runseq(f, xs::BatchSeq...) = rebatchseq(runseq(f, rawbatch.(xs)...)) runseq(f, xs::Batch{<:Seq}...) = convert(Batch{Seq}, runseq(f, rawbatch.(xs)...))
runseq(f, xs) = runseq(f, (xs...,)) runseq(f, xs) = runseq(f, (xs...,))
function (m::SeqModel)(x) function (m::SeqModel)(x)