diff --git a/src/Batches/Batches.jl b/src/Batches/Batches.jl new file mode 100644 index 00000000..0cd77b16 --- /dev/null +++ b/src/Batches/Batches.jl @@ -0,0 +1,14 @@ +module Batches + +using Juno, Lazy + +export CatMat, rawbatch, + Batch, Batched, batchone, tobatch, rebatch, + Seq, BatchSeq, rebatchseq + +include("catmat.jl") +include("batch.jl") +include("seq.jl") +include("iter.jl") + +end diff --git a/src/dims/batching.jl b/src/Batches/batch.jl similarity index 92% rename from src/dims/batching.jl rename to src/Batches/batch.jl index bb82d11e..d6a8e350 100644 --- a/src/dims/batching.jl +++ b/src/Batches/batch.jl @@ -1,5 +1,3 @@ -export Batch, batchone, tobatch - struct Batch{T,S} <: AbstractVector{T} data::CatMat{T,S} end @@ -17,6 +15,7 @@ convert{T,S}(::Type{Batch{T,S}},storage::S) = Juno.trim(collect(b))) end +# TODO: figure out how to express this as a generic convert function rebatch(xs) dims = ndims(xs)-1 T = Array{eltype(xs),dims} diff --git a/src/dims/catmat.jl b/src/Batches/catmat.jl similarity index 95% rename from src/dims/catmat.jl rename to src/Batches/catmat.jl index b4c170e2..c550f451 100644 --- a/src/dims/catmat.jl +++ b/src/Batches/catmat.jl @@ -1,7 +1,5 @@ import Base: eltype, size, getindex, setindex!, convert -export CatMat, rawbatch - struct CatMat{T,S} <: AbstractVector{T} data::S end @@ -26,7 +24,7 @@ end allequal(xs) = all(x -> x == first(xs), xs) function (::Type{CatMat{T,S}}){T,S}(xs, storage::S) - @assert @>> xs map(size) allequal + @assert allequal(map(size, xs)) @assert size(storage) == (length(xs), size(first(xs))...) for i = 1:length(xs) storage[i, :] = xs[i] diff --git a/src/dims/iter.jl b/src/Batches/iter.jl similarity index 98% rename from src/dims/iter.jl rename to src/Batches/iter.jl index d35f14d3..95d85798 100644 --- a/src/dims/iter.jl +++ b/src/Batches/iter.jl @@ -1,5 +1,3 @@ -export Batched - import Base: start, next, done, iteratorsize, iteratoreltype, eltype, length # Stateful iteration diff --git a/src/Batches/seq.jl b/src/Batches/seq.jl new file mode 100644 index 00000000..5fd2e1b6 --- /dev/null +++ b/src/Batches/seq.jl @@ -0,0 +1,26 @@ +struct Seq{T,S} <: AbstractVector{T} + data::CatMat{T,S} +end + +@forward Seq.data size, eltype, getindex, setindex!, rawbatch + +Seq(xs) = Seq(CatMat(xs)) + +convert{T,S}(::Type{Seq{T,S}},storage::S) = + Seq{T,S}(storage) + +@render Juno.Inline b::Seq begin + Tree(Row(Text("Seq of "), eltype(b), + Juno.fade("[$(length(b))]")), + Juno.trim(collect(b))) +end + +BatchSeq{T<:Seq} = Batch{T} + +function rebatchseq(xs) + dims = ndims(xs)-2 + T = Array{eltype(xs),dims} + S = Array{eltype(xs),dims+1} + B = Array{eltype(xs),dims+2} + Batch{Seq{T,S},B}(xs) +end diff --git a/src/Flux.jl b/src/Flux.jl index 3bb87624..3b353dca 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -17,15 +17,13 @@ export @net, unroll, unroll1, @shapes, # Zero Flux Given +include("Batches/Batches.jl") +using .Batches + include("utils.jl") include("model.jl") -include("dims/catmat.jl") -include("dims/batching.jl") -include("dims/seq.jl") -include("dims/iter.jl") - include("compiler/code.jl") include("compiler/loops.jl") include("compiler/interp.jl") diff --git a/src/backend/tensorflow/graph.jl b/src/backend/tensorflow/graph.jl index 9a94b38c..71499ac8 100644 --- a/src/backend/tensorflow/graph.jl +++ b/src/backend/tensorflow/graph.jl @@ -90,7 +90,7 @@ end astensor(model, args...) = tograph(model, args...; variables = true)[3] -RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data)) +RawTensor(data::Union{Flux.Batch,Flux.Seq}) = RawTensor(Flux.rawbatch(data)) # Error Handling diff --git a/src/compiler/code.jl b/src/compiler/code.jl index 1ef4af64..5da3e34c 100644 --- a/src/compiler/code.jl +++ b/src/compiler/code.jl @@ -1,5 +1,5 @@ import DataFlow: cse -using MacroTools: @q +using MacroTools: @q, @> function graphdef(ex, params = []) @capture(shortdef(ex), (args__,) -> body_) diff --git a/src/compiler/loops.jl b/src/compiler/loops.jl index aae9d700..736e814a 100644 --- a/src/compiler/loops.jl +++ b/src/compiler/loops.jl @@ -22,6 +22,31 @@ end update!(m::Stateful, η) = update!(m.model, η) +# Seq Models + +struct SeqModel + model + steps::Int +end + +runseq(f, xs::Tuple...) = f(xs...) +runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2) +runseq(f, xs::BatchSeq...) = rebatchseq(runseq(f, rawbatch.(xs)...)) +runseq(f, xs) = runseq(f, (xs...,)) + +function (m::SeqModel)(x) + runseq(x) do x + @assert length(x) == m.steps "Expected seq length $(m.steps), got $(size(x, 2))" + m.model(x) + end +end + +back!(m::SeqModel, Δ, x) = (runseq((Δ, x) -> back!(m.model, Δ, x)[1], Δ, x),) + +update!(m::SeqModel, η) = update!(m.model, η) + +graph(m::SeqModel) = graph(m.model) + # Recurrent Graphs struct Offset diff --git a/src/dims/seq.jl b/src/dims/seq.jl deleted file mode 100644 index 206527a3..00000000 --- a/src/dims/seq.jl +++ /dev/null @@ -1,53 +0,0 @@ -export Seq, BatchSeq - -struct Seq{T,S} <: AbstractVector{T} - data::CatMat{T,S} -end - -@forward Seq.data size, eltype, getindex, setindex!, rawbatch - -Seq(xs) = Seq(CatMat(xs)) - -convert{T,S}(::Type{Seq{T,S}},storage::S) = - Seq{T,S}(storage) - -@render Juno.Inline b::Seq begin - Tree(Row(Text("Seq of "), eltype(b), - Juno.fade("[$(length(b))]")), - Juno.trim(collect(b))) -end - -BatchSeq{T<:Seq} = Batch{T} - -function rebatchseq(xs) - dims = ndims(xs)-2 - T = Array{eltype(xs),dims} - S = Array{eltype(xs),dims+1} - B = Array{eltype(xs),dims+2} - Batch{Seq{T,S},B}(xs) -end - -# SeqModel wrapper layer for convenience - -struct SeqModel - model - steps::Int -end - -runseq(f, xs::Tuple...) = f(xs...) -runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2) -runseq(f, xs::BatchSeq...) = rebatchseq(runseq(f, rawbatch.(xs)...)) -runseq(f, xs) = runseq(f, (xs...,)) - -function (m::SeqModel)(x) - runseq(x) do x - @assert length(x) == m.steps "Expected seq length $(m.steps), got $(size(x, 2))" - m.model(x) - end -end - -back!(m::SeqModel, Δ, x) = (runseq((Δ, x) -> back!(m.model, Δ, x)[1], Δ, x),) - -update!(m::SeqModel, η) = update!(m.model, η) - -graph(m::SeqModel) = graph(m.model) diff --git a/test/runtests.jl b/test/runtests.jl index 8dd1dd8e..7e949956 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,4 @@ -using Flux, DataFlow, MacroTools, Base.Test +using Flux, Flux.Batches, DataFlow, MacroTools, Base.Test using Flux: graph, Param, squeeze, unsqueeze, back!, update!, flatten using DataFlow: Line, Frame