diff --git a/src/Flux.jl b/src/Flux.jl index 5972bd44..dafa6e78 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -7,11 +7,13 @@ using DataFlow: graphm, syntax, prewalk!, postwalk!, prewalk, postwalk, iscyclic, Constant, constant, isconstant, group, Split, splitnode, detuple, value, inputs, thread!, value, inputs, Split, splitnode, inputnode, spliceinputs, bumpinputs, Line, Frame, applylines +using DataFlow.Interpreter using MacroTools: @q using Juno: Tree, Row # Zero Flux Given +include("dims/utils.jl") include("dims/catmat.jl") include("dims/batching.jl") include("dims/seq.jl") diff --git a/src/compiler/interp.jl b/src/compiler/interp.jl index bd5803fd..4419e680 100644 --- a/src/compiler/interp.jl +++ b/src/compiler/interp.jl @@ -1,5 +1,3 @@ -using DataFlow.Interpreter - function astuple(xs::Vertex) isconstant(xs) && value(xs).value isa Tuple ? value(xs).value : xs isa Vertex && value(xs) == tuple ? inputs(xs) : diff --git a/src/compiler/shape.jl b/src/compiler/shape.jl index e112b07b..709590ad 100644 --- a/src/compiler/shape.jl +++ b/src/compiler/shape.jl @@ -1,5 +1,3 @@ -using DataFlow.Interpreter - export @shapes Dims{N} = NTuple{N,Int} diff --git a/src/dims/seq.jl b/src/dims/seq.jl index fe4fe6cc..ce1eecc1 100644 --- a/src/dims/seq.jl +++ b/src/dims/seq.jl @@ -18,3 +18,11 @@ convert{T,S}(::Type{Seq{T,S}},storage::S) = end BatchSeq{T<:Seq} = Batch{T} + +function rebatchseq(xs) + dims = ndims(xs)-2 + T = Array{eltype(xs),dims} + S = Array{eltype(xs),dims+1} + B = Array{eltype(xs),dims+2} + Batch{Seq{T,S},B}(xs) +end diff --git a/src/dims/utils.jl b/src/dims/utils.jl new file mode 100644 index 00000000..680c85a1 --- /dev/null +++ b/src/dims/utils.jl @@ -0,0 +1,4 @@ +unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...)) + +stack(xs, dim = 1) = cat(dim, unsqueeze.(xs, dim)...) +unstack(xs, dim = 1) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)] diff --git a/src/model.jl b/src/model.jl index 9a4c199a..f8f7f6bb 100644 --- a/src/model.jl +++ b/src/model.jl @@ -145,6 +145,13 @@ struct SeqModel steps::Int end -# TODO: multi input -# TODO: lift sequences -(m::SeqModel)(x) = m.model(x) +(m::SeqModel)(x::Tuple) = m.model(x) + +splitseq(xs) = rebatch.(unstack(rawbatch(xs), 2)) +joinseq(xs) = rebatchseq(stack(rawbatch.(xs), 2)) + +function (m::SeqModel)(x::Union{Seq,BatchSeq}) + runbatched(x) do x + joinseq(m.model((splitseq(x)...,))) + end +end