make Batches submodule

Mike J Innes 2017-06-05 16:56:44 +01:00
parent cf8227c02f
commit 9a460e12f2
11 changed files with 73 additions and 68 deletions

src/Batches/Batches.jl (new file, 14 lines added)
View File

@@ -0,0 +1,14 @@
module Batches
using Juno, Lazy
export CatMat, rawbatch,
Batch, Batched, batchone, tobatch, rebatch,
Seq, BatchSeq, rebatchseq
include("catmat.jl")
include("batch.jl")
include("seq.jl")
include("iter.jl")
end
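The new file only gathers the existing batching code under a Batches namespace; the types themselves are unchanged. A rough usage sketch with hypothetical sample data, assuming Batch has a Batch(xs) = Batch(CatMat(xs)) convenience constructor analogous to the Seq one in seq.jl below:

    using Flux.Batches

    xs = [rand(10) for _ = 1:3]   # three equally-sized samples
    b  = Batch(xs)                # batched view over them; length(b) == 3
    rawbatch(b)                   # the underlying 3×10 storage array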

View File

@@ -1,5 +1,3 @@
export Batch, batchone, tobatch
struct Batch{T,S} <: AbstractVector{T}
data::CatMat{T,S}
end
@@ -17,6 +15,7 @@ convert{T,S}(::Type{Batch{T,S}},storage::S) =
Juno.trim(collect(b)))
end
# TODO: figure out how to express this as a generic convert
function rebatch(xs)
dims = ndims(xs)-1
T = Array{eltype(xs),dims}
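The visible part of rebatch derives the slice type from a raw array whose leading dimension is the batch dimension; the rest of the function (cut off by this hunk) presumably wraps the array as Batch{T,S}. A sketch of the intended round trip under that reading, with hypothetical data:

    using Flux.Batches

    xs = rand(3, 10)      # raw storage: 3 samples of length 10
    b  = rebatch(xs)      # a Batch of three Vector{Float64} slices
    length(b)             # 3
    rawbatch(b)           # the same 3×10 array back (assuming no copy is made)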

View File

@@ -1,7 +1,5 @@
import Base: eltype, size, getindex, setindex!, convert
export CatMat, rawbatch
struct CatMat{T,S} <: AbstractVector{T}
data::S
end
@@ -26,7 +24,7 @@ end
allequal(xs) = all(x -> x == first(xs), xs)
function (::Type{CatMat{T,S}}){T,S}(xs, storage::S)
@assert @>> xs map(size) allequal
@assert allequal(map(size, xs))
@assert size(storage) == (length(xs), size(first(xs))...)
for i = 1:length(xs)
storage[i, :] = xs[i]
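Only the spelling of the first assertion changes, apparently swapping Lazy.jl's @>> threading macro for a plain call; the requirement stays the same: every sample has the same size, and the storage's leading dimension equals the number of samples. With hypothetical shapes (allequal is the helper defined just above):

    xs      = [rand(4, 5) for _ = 1:8]                   # eight 4×5 samples
    storage = zeros(8, 4, 5)                              # leading dim = batch size
    allequal(map(size, xs))                               # true
    size(storage) == (length(xs), size(first(xs))...)     # true: (8, 4, 5)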

View File

@@ -1,5 +1,3 @@
export Batched
import Base: start, next, done, iteratorsize, iteratoreltype, eltype, length
# Stateful iteration

src/Batches/seq.jl (new file, 26 lines added)
View File

@@ -0,0 +1,26 @@
struct Seq{T,S} <: AbstractVector{T}
data::CatMat{T,S}
end
@forward Seq.data size, eltype, getindex, setindex!, rawbatch
Seq(xs) = Seq(CatMat(xs))
convert{T,S}(::Type{Seq{T,S}},storage::S) =
Seq{T,S}(storage)
@render Juno.Inline b::Seq begin
Tree(Row(Text("Seq of "), eltype(b),
Juno.fade("[$(length(b))]")),
Juno.trim(collect(b)))
end
BatchSeq{T<:Seq} = Batch{T}
function rebatchseq(xs)
dims = ndims(xs)-2
T = Array{eltype(xs),dims}
S = Array{eltype(xs),dims+1}
B = Array{eltype(xs),dims+2}
Batch{Seq{T,S},B}(xs)
end
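The dimension arithmetic in rebatchseq reads: drop the two leading dimensions (batch and time) to get the per-timestep element type T, keep the time dimension for the per-sequence type S, and use the full array as storage B. Concretely, for one hypothetical shape:

    using Flux.Batches

    xs = rand(4, 7, 10)   # 4 sequences × 7 timesteps × 10 features
    # dims = ndims(xs) - 2 = 1, so
    #   T = Array{Float64,1}   one timestep (a length-10 vector)
    #   S = Array{Float64,2}   one sequence (7×10)
    #   B = Array{Float64,3}   the whole batch
    b = rebatchseq(xs)    # Batch{Seq{T,S},B} wrapping xs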

View File

@@ -17,15 +17,13 @@ export @net, unroll, unroll1, @shapes,
# Zero Flux Given
include("Batches/Batches.jl")
using .Batches
include("utils.jl")
include("model.jl")
include("dims/catmat.jl")
include("dims/batching.jl")
include("dims/seq.jl")
include("dims/iter.jl")
include("compiler/code.jl")
include("compiler/loops.jl")
include("compiler/interp.jl")

View File

@@ -90,7 +90,7 @@ end
astensor(model, args...) =
tograph(model, args...; variables = true)[3]
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
RawTensor(data::Union{Flux.Batch,Flux.Seq}) = RawTensor(Flux.rawbatch(data))
# Error Handling

View File

@@ -1,5 +1,5 @@
import DataFlow: cse
using MacroTools: @q
using MacroTools: @q, @>
function graphdef(ex, params = [])
@capture(shortdef(ex), (args__,) -> body_)

View File

@@ -22,6 +22,31 @@ end
update!(m::Stateful, η) = update!(m.model, η)
# Seq Models
struct SeqModel
model
steps::Int
end
runseq(f, xs::Tuple...) = f(xs...)
runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2)
runseq(f, xs::BatchSeq...) = rebatchseq(runseq(f, rawbatch.(xs)...))
runseq(f, xs) = runseq(f, (xs...,))
function (m::SeqModel)(x)
runseq(x) do x
@assert length(x) == m.steps "Expected seq length $(m.steps), got $(size(x, 2))"
m.model(x)
end
end
back!(m::SeqModel, Δ, x) = (runseq((Δ, x) -> back!(m.model, Δ, x)[1], Δ, x),)
update!(m::SeqModel, η) = update!(m.model, η)
graph(m::SeqModel) = graph(m.model)
# Recurrent Graphs
struct Offset
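SeqModel (moved here from the deleted dims/seq.jl below) pairs a model with a fixed step count, and runseq does the plumbing: a raw array is unstacked along dimension 2 into a tuple of per-timestep slices, f is applied, and the result is restacked; BatchSeq inputs are unwrapped to raw storage and rewrapped afterwards. A sketch of the shapes involved, using identity as a stand-in for a real model (SeqModel itself is internal, so this is purely illustrative):

    m = SeqModel(identity, 7)   # toy wrapper expecting 7 timesteps
    x = rand(4, 7, 10)          # batch of 4, 7 timesteps, 10 features
    # runseq unstacks x along dim 2 into a 7-tuple of 4×10 slices, the
    # length check against m.steps passes, identity returns the tuple,
    # and stack(..., 2) rebuilds a 4×7×10 array.
    y = m(x)
    size(y)                     # (4, 7, 10)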

View File

@@ -1,53 +0,0 @@
export Seq, BatchSeq
struct Seq{T,S} <: AbstractVector{T}
data::CatMat{T,S}
end
@forward Seq.data size, eltype, getindex, setindex!, rawbatch
Seq(xs) = Seq(CatMat(xs))
convert{T,S}(::Type{Seq{T,S}},storage::S) =
Seq{T,S}(storage)
@render Juno.Inline b::Seq begin
Tree(Row(Text("Seq of "), eltype(b),
Juno.fade("[$(length(b))]")),
Juno.trim(collect(b)))
end
BatchSeq{T<:Seq} = Batch{T}
function rebatchseq(xs)
dims = ndims(xs)-2
T = Array{eltype(xs),dims}
S = Array{eltype(xs),dims+1}
B = Array{eltype(xs),dims+2}
Batch{Seq{T,S},B}(xs)
end
# SeqModel wrapper layer for convenience
struct SeqModel
model
steps::Int
end
runseq(f, xs::Tuple...) = f(xs...)
runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2)
runseq(f, xs::BatchSeq...) = rebatchseq(runseq(f, rawbatch.(xs)...))
runseq(f, xs) = runseq(f, (xs...,))
function (m::SeqModel)(x)
runseq(x) do x
@assert length(x) == m.steps "Expected seq length $(m.steps), got $(size(x, 2))"
m.model(x)
end
end
back!(m::SeqModel, Δ, x) = (runseq((Δ, x) -> back!(m.model, Δ, x)[1], Δ, x),)
update!(m::SeqModel, η) = update!(m.model, η)
graph(m::SeqModel) = graph(m.model)

View File

@@ -1,4 +1,4 @@
using Flux, DataFlow, MacroTools, Base.Test
using Flux, Flux.Batches, DataFlow, MacroTools, Base.Test
using Flux: graph, Param, squeeze, unsqueeze, back!, update!, flatten
using DataFlow: Line, Frame