From 13f4af24924a17ec9a4ea8d81c90472c74f05629 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 5 Jun 2017 22:49:31 +0100 Subject: [PATCH] generic and consistent conversions --- src/Batches/Batches.jl | 4 +--- src/Batches/batch.jl | 18 ------------------ src/Batches/catmat.jl | 41 +++++++++++++++++++++++++++++------------ src/Batches/iter.jl | 2 +- src/compiler/loops.jl | 2 +- 5 files changed, 32 insertions(+), 35 deletions(-) diff --git a/src/Batches/Batches.jl b/src/Batches/Batches.jl index 1abb41b8..433d0bdf 100644 --- a/src/Batches/Batches.jl +++ b/src/Batches/Batches.jl @@ -3,9 +3,7 @@ module Batches using Juno, Lazy using Juno: Tree, Row -export Storage, rawbatch, - Batch, Batched, batchone, tobatch, rebatch, - Seq, BatchSeq, rebatchseq +export Batch, Batched, Seq, rawbatch, batchone include("catmat.jl") include("batch.jl") diff --git a/src/Batches/batch.jl b/src/Batches/batch.jl index f1fbd812..10958516 100644 --- a/src/Batches/batch.jl +++ b/src/Batches/batch.jl @@ -6,14 +6,6 @@ end Batch(xs) = Batch(Storage(xs)) -# TODO: figure out how to express this as a generic convert -function rebatch(xs) - dims = ndims(xs)-1 - T = Array{eltype(xs),dims} - B = Array{eltype(xs),dims+1} - Batch{T,B}(xs) -end - convertel(T::Type, xs::Batch) = eltype(eltype(xs)) isa T ? xs : Batch(map(x->convertel(T, x), xs)) @@ -31,13 +23,3 @@ struct Seq{T,S} <: BatchLike{T,S} end Seq(xs) = Seq(Storage(xs)) - -BatchSeq{T<:Seq} = Batch{T} - -function rebatchseq(xs) - dims = ndims(xs)-2 - T = Array{eltype(xs),dims} - S = Array{eltype(xs),dims+1} - B = Array{eltype(xs),dims+2} - Batch{Seq{T,S},B}(xs) -end diff --git a/src/Batches/catmat.jl b/src/Batches/catmat.jl index 2c24f1c6..43651e84 100644 --- a/src/Batches/catmat.jl +++ b/src/Batches/catmat.jl @@ -1,7 +1,9 @@ -import Base: eltype, size, getindex, setindex!, convert +import Base: eltype, size, getindex, setindex!, convert, typename rawbatch(xs) = xs +# Generic methods + abstract type BatchLike{T,S} <: AbstractVector{T} end @@ -21,17 +23,17 @@ function setindex!(b::BatchLike, xs, ::Colon) end end -typename(b::Type) = b -typename(b::Type{<:BatchLike}) = - Row(Juno.typ("$(b.name.name)"), text"{", typename(eltype(b)), text"}") +typerender(B::Type) = B +typerender(B::Type{<:BatchLike}) = + Row(Juno.typ("$(typename(B).name)"), text"{", typerender(eltype(B)), text"}") @render Juno.Inline b::BatchLike begin - Tree(Row(typename(typeof(b)), + Tree(Row(typerender(typeof(b)), Juno.fade("[$(length(b))]")), Juno.trim(collect(b))) end -convert{T,S}(B::Type{<:BatchLike{T,S}},storage::S) = B(storage) +# Concrete storage struct Storage{T,S} <: BatchLike{T,S} data::S @@ -39,7 +41,7 @@ end allequal(xs) = all(x -> x == first(xs), xs) -function (::Type{Storage{T,S}}){T,S}(xs, storage::S) +function Storage{T,S}(xs, storage::S) where {T, S} @assert allequal(map(size, xs)) @assert size(storage) == (length(xs), size(first(xs))...) for i = 1:length(xs) @@ -48,13 +50,28 @@ function (::Type{Storage{T,S}}){T,S}(xs, storage::S) return Storage{T,S}(storage) end -function (::Type{Storage{T}}){T}(xs) +function Storage{T}(xs) where T xs′ = map(rawbatch, xs) storage = similar(first(xs′), (length(xs′), size(first(xs′))...)) Storage{T,typeof(storage)}(xs′, storage) end -function Storage(xs) - xs = promote(xs...) - Storage{eltype(xs)}(xs) -end +Storage(xs) = Storage{eltype(xs)}(xs) + +convert{T,S}(B::Type{<:BatchLike{T,S}}, data::S) = B(data) + +# Horrible type hacks follow this point + +deparam(T::Type) = typename(T).wrapper + +dimless(T::Type{<:AbstractArray}) = ndims(T) == 1 ? eltype(T) : deparam(T){eltype(T),ndims(T)-1} + +btype(B::Type{<:BatchLike}, S::Type{<:AbstractArray}) = B{dimless(S),S} +btype(B::Type{<:BatchLike{T}} where T, S::Type{<:AbstractArray}) = B{S} +btype(B::Type{<:BatchLike{<:BatchLike}}, S::Type{<:AbstractArray}) = + deparam(B){btype(eltype(B), dimless(S)),S} + +convert{T<:BatchLike}(::Type{T}, xs::AbstractArray) = + convert(btype(T, typeof(xs)), xs) + +convert{T<:BatchLike}(::Type{T}, x::T) = x diff --git a/src/Batches/iter.jl b/src/Batches/iter.jl index 95d85798..051541e6 100644 --- a/src/Batches/iter.jl +++ b/src/Batches/iter.jl @@ -48,7 +48,7 @@ end function Batched(itr, n::Integer) n >= 1 || throw(ArgumentError("batch size must be >= 1")) itr = StatefulIter(itr) - buf = rebatch(similar(eltype(itr)(), n, size(peek(itr))...)) + buf = convert(Batch, similar(eltype(itr)(), n, size(peek(itr))...)) Batched(itr, buf) end diff --git a/src/compiler/loops.jl b/src/compiler/loops.jl index 736e814a..33f38cdd 100644 --- a/src/compiler/loops.jl +++ b/src/compiler/loops.jl @@ -31,7 +31,7 @@ end runseq(f, xs::Tuple...) = f(xs...) runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2) -runseq(f, xs::BatchSeq...) = rebatchseq(runseq(f, rawbatch.(xs)...)) +runseq(f, xs::Batch{<:Seq}...) = convert(Batch{Seq}, runseq(f, rawbatch.(xs)...)) runseq(f, xs) = runseq(f, (xs...,)) function (m::SeqModel)(x)