working seqmodel

This commit is contained in:
Mike J Innes 2017-03-28 19:54:32 +01:00
parent ab8f0c2dc8
commit 2082d9db5c
6 changed files with 24 additions and 7 deletions

View File

@ -7,11 +7,13 @@ using DataFlow: graphm, syntax, prewalk!, postwalk!, prewalk, postwalk,
iscyclic, Constant, constant, isconstant, group, Split, splitnode,
detuple, value, inputs, thread!, value, inputs, Split, splitnode, inputnode,
spliceinputs, bumpinputs, Line, Frame, applylines
using DataFlow.Interpreter
using MacroTools: @q
using Juno: Tree, Row
# Zero Flux Given
include("dims/utils.jl")
include("dims/catmat.jl")
include("dims/batching.jl")
include("dims/seq.jl")

View File

@ -1,5 +1,3 @@
using DataFlow.Interpreter
function astuple(xs::Vertex)
isconstant(xs) && value(xs).value isa Tuple ? value(xs).value :
xs isa Vertex && value(xs) == tuple ? inputs(xs) :

View File

@ -1,5 +1,3 @@
using DataFlow.Interpreter
export @shapes
Dims{N} = NTuple{N,Int}

View File

@ -18,3 +18,11 @@ convert{T,S}(::Type{Seq{T,S}},storage::S) =
end
BatchSeq{T<:Seq} = Batch{T}
function rebatchseq(xs)
dims = ndims(xs)-2
T = Array{eltype(xs),dims}
S = Array{eltype(xs),dims+1}
B = Array{eltype(xs),dims+2}
Batch{Seq{T,S},B}(xs)
end

4
src/dims/utils.jl Normal file
View File

@ -0,0 +1,4 @@
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
stack(xs, dim = 1) = cat(dim, unsqueeze.(xs, dim)...)
unstack(xs, dim = 1) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]

View File

@ -145,6 +145,13 @@ struct SeqModel
steps::Int
end
# TODO: multi input
# TODO: lift sequences
(m::SeqModel)(x) = m.model(x)
(m::SeqModel)(x::Tuple) = m.model(x)
splitseq(xs) = rebatch.(unstack(rawbatch(xs), 2))
joinseq(xs) = rebatchseq(stack(rawbatch.(xs), 2))
function (m::SeqModel)(x::Union{Seq,BatchSeq})
runbatched(x) do x
joinseq(m.model((splitseq(x)...,)))
end
end