make Batches submodule
This commit is contained in:
parent
cf8227c02f
commit
9a460e12f2
14
src/Batches/Batches.jl
Normal file
14
src/Batches/Batches.jl
Normal file
@ -0,0 +1,14 @@
|
||||
module Batches
|
||||
|
||||
using Juno, Lazy
|
||||
|
||||
export CatMat, rawbatch,
|
||||
Batch, Batched, batchone, tobatch, rebatch,
|
||||
Seq, BatchSeq, rebatchseq
|
||||
|
||||
include("catmat.jl")
|
||||
include("batch.jl")
|
||||
include("seq.jl")
|
||||
include("iter.jl")
|
||||
|
||||
end
|
@ -1,5 +1,3 @@
|
||||
export Batch, batchone, tobatch
|
||||
|
||||
struct Batch{T,S} <: AbstractVector{T}
|
||||
data::CatMat{T,S}
|
||||
end
|
||||
@ -17,6 +15,7 @@ convert{T,S}(::Type{Batch{T,S}},storage::S) =
|
||||
Juno.trim(collect(b)))
|
||||
end
|
||||
|
||||
# TODO: figure out how to express this as a generic convert
|
||||
function rebatch(xs)
|
||||
dims = ndims(xs)-1
|
||||
T = Array{eltype(xs),dims}
|
@ -1,7 +1,5 @@
|
||||
import Base: eltype, size, getindex, setindex!, convert
|
||||
|
||||
export CatMat, rawbatch
|
||||
|
||||
struct CatMat{T,S} <: AbstractVector{T}
|
||||
data::S
|
||||
end
|
||||
@ -26,7 +24,7 @@ end
|
||||
allequal(xs) = all(x -> x == first(xs), xs)
|
||||
|
||||
function (::Type{CatMat{T,S}}){T,S}(xs, storage::S)
|
||||
@assert @>> xs map(size) allequal
|
||||
@assert allequal(map(size, xs))
|
||||
@assert size(storage) == (length(xs), size(first(xs))...)
|
||||
for i = 1:length(xs)
|
||||
storage[i, :] = xs[i]
|
@ -1,5 +1,3 @@
|
||||
export Batched
|
||||
|
||||
import Base: start, next, done, iteratorsize, iteratoreltype, eltype, length
|
||||
|
||||
# Stateful iteration
|
26
src/Batches/seq.jl
Normal file
26
src/Batches/seq.jl
Normal file
@ -0,0 +1,26 @@
|
||||
struct Seq{T,S} <: AbstractVector{T}
|
||||
data::CatMat{T,S}
|
||||
end
|
||||
|
||||
@forward Seq.data size, eltype, getindex, setindex!, rawbatch
|
||||
|
||||
Seq(xs) = Seq(CatMat(xs))
|
||||
|
||||
convert{T,S}(::Type{Seq{T,S}},storage::S) =
|
||||
Seq{T,S}(storage)
|
||||
|
||||
@render Juno.Inline b::Seq begin
|
||||
Tree(Row(Text("Seq of "), eltype(b),
|
||||
Juno.fade("[$(length(b))]")),
|
||||
Juno.trim(collect(b)))
|
||||
end
|
||||
|
||||
BatchSeq{T<:Seq} = Batch{T}
|
||||
|
||||
function rebatchseq(xs)
|
||||
dims = ndims(xs)-2
|
||||
T = Array{eltype(xs),dims}
|
||||
S = Array{eltype(xs),dims+1}
|
||||
B = Array{eltype(xs),dims+2}
|
||||
Batch{Seq{T,S},B}(xs)
|
||||
end
|
@ -17,15 +17,13 @@ export @net, unroll, unroll1, @shapes,
|
||||
|
||||
# Zero Flux Given
|
||||
|
||||
include("Batches/Batches.jl")
|
||||
using .Batches
|
||||
|
||||
include("utils.jl")
|
||||
|
||||
include("model.jl")
|
||||
|
||||
include("dims/catmat.jl")
|
||||
include("dims/batching.jl")
|
||||
include("dims/seq.jl")
|
||||
include("dims/iter.jl")
|
||||
|
||||
include("compiler/code.jl")
|
||||
include("compiler/loops.jl")
|
||||
include("compiler/interp.jl")
|
||||
|
@ -90,7 +90,7 @@ end
|
||||
astensor(model, args...) =
|
||||
tograph(model, args...; variables = true)[3]
|
||||
|
||||
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
|
||||
RawTensor(data::Union{Flux.Batch,Flux.Seq}) = RawTensor(Flux.rawbatch(data))
|
||||
|
||||
# Error Handling
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
import DataFlow: cse
|
||||
using MacroTools: @q
|
||||
using MacroTools: @q, @>
|
||||
|
||||
function graphdef(ex, params = [])
|
||||
@capture(shortdef(ex), (args__,) -> body_)
|
||||
|
@ -22,6 +22,31 @@ end
|
||||
|
||||
update!(m::Stateful, η) = update!(m.model, η)
|
||||
|
||||
# Seq Models
|
||||
|
||||
struct SeqModel
|
||||
model
|
||||
steps::Int
|
||||
end
|
||||
|
||||
runseq(f, xs::Tuple...) = f(xs...)
|
||||
runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2)
|
||||
runseq(f, xs::BatchSeq...) = rebatchseq(runseq(f, rawbatch.(xs)...))
|
||||
runseq(f, xs) = runseq(f, (xs...,))
|
||||
|
||||
function (m::SeqModel)(x)
|
||||
runseq(x) do x
|
||||
@assert length(x) == m.steps "Expected seq length $(m.steps), got $(size(x, 2))"
|
||||
m.model(x)
|
||||
end
|
||||
end
|
||||
|
||||
back!(m::SeqModel, Δ, x) = (runseq((Δ, x) -> back!(m.model, Δ, x)[1], Δ, x),)
|
||||
|
||||
update!(m::SeqModel, η) = update!(m.model, η)
|
||||
|
||||
graph(m::SeqModel) = graph(m.model)
|
||||
|
||||
# Recurrent Graphs
|
||||
|
||||
struct Offset
|
||||
|
@ -1,53 +0,0 @@
|
||||
export Seq, BatchSeq
|
||||
|
||||
struct Seq{T,S} <: AbstractVector{T}
|
||||
data::CatMat{T,S}
|
||||
end
|
||||
|
||||
@forward Seq.data size, eltype, getindex, setindex!, rawbatch
|
||||
|
||||
Seq(xs) = Seq(CatMat(xs))
|
||||
|
||||
convert{T,S}(::Type{Seq{T,S}},storage::S) =
|
||||
Seq{T,S}(storage)
|
||||
|
||||
@render Juno.Inline b::Seq begin
|
||||
Tree(Row(Text("Seq of "), eltype(b),
|
||||
Juno.fade("[$(length(b))]")),
|
||||
Juno.trim(collect(b)))
|
||||
end
|
||||
|
||||
BatchSeq{T<:Seq} = Batch{T}
|
||||
|
||||
function rebatchseq(xs)
|
||||
dims = ndims(xs)-2
|
||||
T = Array{eltype(xs),dims}
|
||||
S = Array{eltype(xs),dims+1}
|
||||
B = Array{eltype(xs),dims+2}
|
||||
Batch{Seq{T,S},B}(xs)
|
||||
end
|
||||
|
||||
# SeqModel wrapper layer for convenience
|
||||
|
||||
struct SeqModel
|
||||
model
|
||||
steps::Int
|
||||
end
|
||||
|
||||
runseq(f, xs::Tuple...) = f(xs...)
|
||||
runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2)
|
||||
runseq(f, xs::BatchSeq...) = rebatchseq(runseq(f, rawbatch.(xs)...))
|
||||
runseq(f, xs) = runseq(f, (xs...,))
|
||||
|
||||
function (m::SeqModel)(x)
|
||||
runseq(x) do x
|
||||
@assert length(x) == m.steps "Expected seq length $(m.steps), got $(size(x, 2))"
|
||||
m.model(x)
|
||||
end
|
||||
end
|
||||
|
||||
back!(m::SeqModel, Δ, x) = (runseq((Δ, x) -> back!(m.model, Δ, x)[1], Δ, x),)
|
||||
|
||||
update!(m::SeqModel, η) = update!(m.model, η)
|
||||
|
||||
graph(m::SeqModel) = graph(m.model)
|
@ -1,4 +1,4 @@
|
||||
using Flux, DataFlow, MacroTools, Base.Test
|
||||
using Flux, Flux.Batches, DataFlow, MacroTools, Base.Test
|
||||
using Flux: graph, Param, squeeze, unsqueeze, back!, update!, flatten
|
||||
using DataFlow: Line, Frame
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user