make Batches submodule
This commit is contained in:
parent
cf8227c02f
commit
9a460e12f2
14
src/Batches/Batches.jl
Normal file
14
src/Batches/Batches.jl
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
module Batches
|
||||||
|
|
||||||
|
using Juno, Lazy
|
||||||
|
|
||||||
|
export CatMat, rawbatch,
|
||||||
|
Batch, Batched, batchone, tobatch, rebatch,
|
||||||
|
Seq, BatchSeq, rebatchseq
|
||||||
|
|
||||||
|
include("catmat.jl")
|
||||||
|
include("batch.jl")
|
||||||
|
include("seq.jl")
|
||||||
|
include("iter.jl")
|
||||||
|
|
||||||
|
end
|
@ -1,5 +1,3 @@
|
|||||||
export Batch, batchone, tobatch
|
|
||||||
|
|
||||||
struct Batch{T,S} <: AbstractVector{T}
|
struct Batch{T,S} <: AbstractVector{T}
|
||||||
data::CatMat{T,S}
|
data::CatMat{T,S}
|
||||||
end
|
end
|
||||||
@ -17,6 +15,7 @@ convert{T,S}(::Type{Batch{T,S}},storage::S) =
|
|||||||
Juno.trim(collect(b)))
|
Juno.trim(collect(b)))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# TODO: figure out how to express this as a generic convert
|
||||||
function rebatch(xs)
|
function rebatch(xs)
|
||||||
dims = ndims(xs)-1
|
dims = ndims(xs)-1
|
||||||
T = Array{eltype(xs),dims}
|
T = Array{eltype(xs),dims}
|
@ -1,7 +1,5 @@
|
|||||||
import Base: eltype, size, getindex, setindex!, convert
|
import Base: eltype, size, getindex, setindex!, convert
|
||||||
|
|
||||||
export CatMat, rawbatch
|
|
||||||
|
|
||||||
struct CatMat{T,S} <: AbstractVector{T}
|
struct CatMat{T,S} <: AbstractVector{T}
|
||||||
data::S
|
data::S
|
||||||
end
|
end
|
||||||
@ -26,7 +24,7 @@ end
|
|||||||
allequal(xs) = all(x -> x == first(xs), xs)
|
allequal(xs) = all(x -> x == first(xs), xs)
|
||||||
|
|
||||||
function (::Type{CatMat{T,S}}){T,S}(xs, storage::S)
|
function (::Type{CatMat{T,S}}){T,S}(xs, storage::S)
|
||||||
@assert @>> xs map(size) allequal
|
@assert allequal(map(size, xs))
|
||||||
@assert size(storage) == (length(xs), size(first(xs))...)
|
@assert size(storage) == (length(xs), size(first(xs))...)
|
||||||
for i = 1:length(xs)
|
for i = 1:length(xs)
|
||||||
storage[i, :] = xs[i]
|
storage[i, :] = xs[i]
|
@ -1,5 +1,3 @@
|
|||||||
export Batched
|
|
||||||
|
|
||||||
import Base: start, next, done, iteratorsize, iteratoreltype, eltype, length
|
import Base: start, next, done, iteratorsize, iteratoreltype, eltype, length
|
||||||
|
|
||||||
# Stateful iteration
|
# Stateful iteration
|
26
src/Batches/seq.jl
Normal file
26
src/Batches/seq.jl
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
struct Seq{T,S} <: AbstractVector{T}
|
||||||
|
data::CatMat{T,S}
|
||||||
|
end
|
||||||
|
|
||||||
|
@forward Seq.data size, eltype, getindex, setindex!, rawbatch
|
||||||
|
|
||||||
|
Seq(xs) = Seq(CatMat(xs))
|
||||||
|
|
||||||
|
convert{T,S}(::Type{Seq{T,S}},storage::S) =
|
||||||
|
Seq{T,S}(storage)
|
||||||
|
|
||||||
|
@render Juno.Inline b::Seq begin
|
||||||
|
Tree(Row(Text("Seq of "), eltype(b),
|
||||||
|
Juno.fade("[$(length(b))]")),
|
||||||
|
Juno.trim(collect(b)))
|
||||||
|
end
|
||||||
|
|
||||||
|
BatchSeq{T<:Seq} = Batch{T}
|
||||||
|
|
||||||
|
function rebatchseq(xs)
|
||||||
|
dims = ndims(xs)-2
|
||||||
|
T = Array{eltype(xs),dims}
|
||||||
|
S = Array{eltype(xs),dims+1}
|
||||||
|
B = Array{eltype(xs),dims+2}
|
||||||
|
Batch{Seq{T,S},B}(xs)
|
||||||
|
end
|
@ -17,15 +17,13 @@ export @net, unroll, unroll1, @shapes,
|
|||||||
|
|
||||||
# Zero Flux Given
|
# Zero Flux Given
|
||||||
|
|
||||||
|
include("Batches/Batches.jl")
|
||||||
|
using .Batches
|
||||||
|
|
||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
|
|
||||||
include("model.jl")
|
include("model.jl")
|
||||||
|
|
||||||
include("dims/catmat.jl")
|
|
||||||
include("dims/batching.jl")
|
|
||||||
include("dims/seq.jl")
|
|
||||||
include("dims/iter.jl")
|
|
||||||
|
|
||||||
include("compiler/code.jl")
|
include("compiler/code.jl")
|
||||||
include("compiler/loops.jl")
|
include("compiler/loops.jl")
|
||||||
include("compiler/interp.jl")
|
include("compiler/interp.jl")
|
||||||
|
@ -90,7 +90,7 @@ end
|
|||||||
astensor(model, args...) =
|
astensor(model, args...) =
|
||||||
tograph(model, args...; variables = true)[3]
|
tograph(model, args...; variables = true)[3]
|
||||||
|
|
||||||
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
|
RawTensor(data::Union{Flux.Batch,Flux.Seq}) = RawTensor(Flux.rawbatch(data))
|
||||||
|
|
||||||
# Error Handling
|
# Error Handling
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import DataFlow: cse
|
import DataFlow: cse
|
||||||
using MacroTools: @q
|
using MacroTools: @q, @>
|
||||||
|
|
||||||
function graphdef(ex, params = [])
|
function graphdef(ex, params = [])
|
||||||
@capture(shortdef(ex), (args__,) -> body_)
|
@capture(shortdef(ex), (args__,) -> body_)
|
||||||
|
@ -22,6 +22,31 @@ end
|
|||||||
|
|
||||||
update!(m::Stateful, η) = update!(m.model, η)
|
update!(m::Stateful, η) = update!(m.model, η)
|
||||||
|
|
||||||
|
# Seq Models
|
||||||
|
|
||||||
|
struct SeqModel
|
||||||
|
model
|
||||||
|
steps::Int
|
||||||
|
end
|
||||||
|
|
||||||
|
runseq(f, xs::Tuple...) = f(xs...)
|
||||||
|
runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2)
|
||||||
|
runseq(f, xs::BatchSeq...) = rebatchseq(runseq(f, rawbatch.(xs)...))
|
||||||
|
runseq(f, xs) = runseq(f, (xs...,))
|
||||||
|
|
||||||
|
function (m::SeqModel)(x)
|
||||||
|
runseq(x) do x
|
||||||
|
@assert length(x) == m.steps "Expected seq length $(m.steps), got $(size(x, 2))"
|
||||||
|
m.model(x)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
back!(m::SeqModel, Δ, x) = (runseq((Δ, x) -> back!(m.model, Δ, x)[1], Δ, x),)
|
||||||
|
|
||||||
|
update!(m::SeqModel, η) = update!(m.model, η)
|
||||||
|
|
||||||
|
graph(m::SeqModel) = graph(m.model)
|
||||||
|
|
||||||
# Recurrent Graphs
|
# Recurrent Graphs
|
||||||
|
|
||||||
struct Offset
|
struct Offset
|
||||||
|
@ -1,53 +0,0 @@
|
|||||||
export Seq, BatchSeq
|
|
||||||
|
|
||||||
struct Seq{T,S} <: AbstractVector{T}
|
|
||||||
data::CatMat{T,S}
|
|
||||||
end
|
|
||||||
|
|
||||||
@forward Seq.data size, eltype, getindex, setindex!, rawbatch
|
|
||||||
|
|
||||||
Seq(xs) = Seq(CatMat(xs))
|
|
||||||
|
|
||||||
convert{T,S}(::Type{Seq{T,S}},storage::S) =
|
|
||||||
Seq{T,S}(storage)
|
|
||||||
|
|
||||||
@render Juno.Inline b::Seq begin
|
|
||||||
Tree(Row(Text("Seq of "), eltype(b),
|
|
||||||
Juno.fade("[$(length(b))]")),
|
|
||||||
Juno.trim(collect(b)))
|
|
||||||
end
|
|
||||||
|
|
||||||
BatchSeq{T<:Seq} = Batch{T}
|
|
||||||
|
|
||||||
function rebatchseq(xs)
|
|
||||||
dims = ndims(xs)-2
|
|
||||||
T = Array{eltype(xs),dims}
|
|
||||||
S = Array{eltype(xs),dims+1}
|
|
||||||
B = Array{eltype(xs),dims+2}
|
|
||||||
Batch{Seq{T,S},B}(xs)
|
|
||||||
end
|
|
||||||
|
|
||||||
# SeqModel wrapper layer for convenience
|
|
||||||
|
|
||||||
struct SeqModel
|
|
||||||
model
|
|
||||||
steps::Int
|
|
||||||
end
|
|
||||||
|
|
||||||
runseq(f, xs::Tuple...) = f(xs...)
|
|
||||||
runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2)
|
|
||||||
runseq(f, xs::BatchSeq...) = rebatchseq(runseq(f, rawbatch.(xs)...))
|
|
||||||
runseq(f, xs) = runseq(f, (xs...,))
|
|
||||||
|
|
||||||
function (m::SeqModel)(x)
|
|
||||||
runseq(x) do x
|
|
||||||
@assert length(x) == m.steps "Expected seq length $(m.steps), got $(size(x, 2))"
|
|
||||||
m.model(x)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
back!(m::SeqModel, Δ, x) = (runseq((Δ, x) -> back!(m.model, Δ, x)[1], Δ, x),)
|
|
||||||
|
|
||||||
update!(m::SeqModel, η) = update!(m.model, η)
|
|
||||||
|
|
||||||
graph(m::SeqModel) = graph(m.model)
|
|
@ -1,4 +1,4 @@
|
|||||||
using Flux, DataFlow, MacroTools, Base.Test
|
using Flux, Flux.Batches, DataFlow, MacroTools, Base.Test
|
||||||
using Flux: graph, Param, squeeze, unsqueeze, back!, update!, flatten
|
using Flux: graph, Param, squeeze, unsqueeze, back!, update!, flatten
|
||||||
using DataFlow: Line, Frame
|
using DataFlow: Line, Frame
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user