working seqmodel
This commit is contained in:
parent
ab8f0c2dc8
commit
2082d9db5c
@ -7,11 +7,13 @@ using DataFlow: graphm, syntax, prewalk!, postwalk!, prewalk, postwalk,
|
||||
iscyclic, Constant, constant, isconstant, group, Split, splitnode,
|
||||
detuple, value, inputs, thread!, value, inputs, Split, splitnode, inputnode,
|
||||
spliceinputs, bumpinputs, Line, Frame, applylines
|
||||
using DataFlow.Interpreter
|
||||
using MacroTools: @q
|
||||
using Juno: Tree, Row
|
||||
|
||||
# Zero Flux Given
|
||||
|
||||
include("dims/utils.jl")
|
||||
include("dims/catmat.jl")
|
||||
include("dims/batching.jl")
|
||||
include("dims/seq.jl")
|
||||
|
@ -1,5 +1,3 @@
|
||||
using DataFlow.Interpreter
|
||||
|
||||
function astuple(xs::Vertex)
|
||||
isconstant(xs) && value(xs).value isa Tuple ? value(xs).value :
|
||||
xs isa Vertex && value(xs) == tuple ? inputs(xs) :
|
||||
|
@ -1,5 +1,3 @@
|
||||
using DataFlow.Interpreter
|
||||
|
||||
export @shapes
|
||||
|
||||
Dims{N} = NTuple{N,Int}
|
||||
|
@ -18,3 +18,11 @@ convert{T,S}(::Type{Seq{T,S}},storage::S) =
|
||||
end
|
||||
|
||||
BatchSeq{T<:Seq} = Batch{T}
|
||||
|
||||
function rebatchseq(xs)
|
||||
dims = ndims(xs)-2
|
||||
T = Array{eltype(xs),dims}
|
||||
S = Array{eltype(xs),dims+1}
|
||||
B = Array{eltype(xs),dims+2}
|
||||
Batch{Seq{T,S},B}(xs)
|
||||
end
|
||||
|
4
src/dims/utils.jl
Normal file
4
src/dims/utils.jl
Normal file
@ -0,0 +1,4 @@
|
||||
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
|
||||
|
||||
stack(xs, dim = 1) = cat(dim, unsqueeze.(xs, dim)...)
|
||||
unstack(xs, dim = 1) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
|
13
src/model.jl
13
src/model.jl
@ -145,6 +145,13 @@ struct SeqModel
|
||||
steps::Int
|
||||
end
|
||||
|
||||
# TODO: multi input
|
||||
# TODO: lift sequences
|
||||
(m::SeqModel)(x) = m.model(x)
|
||||
(m::SeqModel)(x::Tuple) = m.model(x)
|
||||
|
||||
splitseq(xs) = rebatch.(unstack(rawbatch(xs), 2))
|
||||
joinseq(xs) = rebatchseq(stack(rawbatch.(xs), 2))
|
||||
|
||||
function (m::SeqModel)(x::Union{Seq,BatchSeq})
|
||||
runbatched(x) do x
|
||||
joinseq(m.model((splitseq(x)...,)))
|
||||
end
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user