diff --git a/src/Batches/Batches.jl b/src/Batches/Batches.jl deleted file mode 100644 index 433d0bdf..00000000 --- a/src/Batches/Batches.jl +++ /dev/null @@ -1,12 +0,0 @@ -module Batches - -using Juno, Lazy -using Juno: Tree, Row - -export Batch, Batched, Seq, rawbatch, batchone - -include("catmat.jl") -include("batch.jl") -include("iter.jl") - -end diff --git a/src/Batches/batch.jl b/src/Batches/batch.jl deleted file mode 100644 index 2cb82239..00000000 --- a/src/Batches/batch.jl +++ /dev/null @@ -1,40 +0,0 @@ -# Batches - -struct Batch{T,S} <: Batchable{T,S} - data::Storage{T,S} - Batch{T,S}(data::Storage{T,S}) where {T,S} = new{T,S}(data) -end - -Batch(data::Storage{T,S}) where {T,S} = Batch{T,S}(data) - -Batch(xs) = Batch(Storage(xs)) -Batch{T,S}(xs) where {T,S} = Batch{T,S}(Storage{T,S}(xs)) - -storage(b::Batch) = b.data - -convertel(T::Type, xs::Batch) = - eltype(eltype(xs)) isa T ? xs : - Batch(map(x->convertel(T, x), xs)) - -batchone(x) = Batch((x,)) -batchone(x::Batch) = x - -tobatch(xs::Batch) = rawbatch(xs) -tobatch(xs) = tobatch(batchone(xs)) - -# Sequences - -struct Seq{T,S} <: Batchable{T,S} - data::Storage{T,S} - Seq{T,S}(data::Storage{T,S}) where {T,S} = new{T,S}(data) -end - -Seq(data::Storage{T,S}) where {T,S} = Seq{T,S}(data) - -Seq(xs) = Seq(Storage(xs)) -Seq{T,S}(xs) where {T,S} = Seq{T,S}(Storage{T,S}(xs)) - -storage(s::Seq) = s.data - -Base.rpad{T}(xs::Seq{T}, n::Integer, x::T) = - n-length(xs) ≤ 0 ? xs : vcat(xs, typeof(xs)(repeated(x, n-length(xs)))) diff --git a/src/Batches/catmat.jl b/src/Batches/catmat.jl deleted file mode 100644 index dcf336c1..00000000 --- a/src/Batches/catmat.jl +++ /dev/null @@ -1,100 +0,0 @@ -import Base: eltype, size, getindex, setindex!, convert, typename - -# Concrete storage - -struct Storage{T,S} - data::S - Storage{T,S}(data::S) where {T,S} = new{T,S}(data) -end - -allequal(xs) = all(x -> x == first(xs), xs) - -function Storage{T,S}(xs, storage::S) where {T, S} - @assert allequal(map(size, xs)) - @assert size(storage) == (length(xs), size(first(xs))...) - for i = 1:length(xs) - storage[i, :] = xs[i] - end - return Storage{T,S}(storage) -end - -function Storage{T,S}(xs) where {T,S} - xs′ = map(rawbatch, xs) - storage = S(length(xs′), size(first(xs′))...) - Storage{T,typeof(storage)}(xs′, storage) -end - -Storage{T}(xs) where T = Storage{T,diminc(typeof(rawbatch(first(xs))))}(xs) - -Storage(xs) = Storage{eltype(xs)}(xs) - -convert{T,S}(::Type{Storage{T,S}}, data::S) = Storage{T,S}(data) - -convert{T}(::Type{Storage{T}}, data::AbstractArray) = convert(Storage{T,typeof(data)}, data) - -# Storage utility methods - -rawbatch(xs) = xs -rawbatch(xs::Storage) = xs.data - -eltype{T}(::Storage{T}) = T - -size(b::Storage) = (size(b.data, 1),) - -getindex(b::Storage, i)::eltype(b) = slicedim(b.data, 1, i) - -setindex!(b::Storage, v, i::Integer) = b.data[i, :] = v - -function setindex!(b::Storage, xs, ::Colon) - for (i, x) in enumerate(xs) - b[i] = x - end -end - -# Generic methods - -abstract type Batchable{T,S} <: AbstractVector{T} -end - -rawbatch(xs::Batchable) = rawbatch(storage(xs)) -size(xs::Batchable) = size(storage(xs)) -getindex(xs::Batchable, i) = getindex(storage(xs), i) -setindex!(xs::Batchable, v, i...) = setindex!(storage(xs), v, i...) - -Base.vcat{T<:Batchable}(xs::T, ys::T)::T = vcat(rawbatch(xs), rawbatch(ys)) - -typerender(B::Type) = B -typerender(B::Type{<:Batchable}) = - Row(Juno.typ("$(typename(B).name)"), text"{", typerender(eltype(B)), text"}") - -@render Juno.Inline b::Batchable begin - Tree(Row(typerender(typeof(b)), - Juno.fade("[$(length(b))]")), - Juno.trim(collect(b))) -end - -# Horrible type hacks follow this point - -deparam(T::Type) = typename(T).wrapper - -diminc(T::Type) = Vector{T} -diminc(T::Type{<:AbstractArray}) = deparam(T){eltype(T),ndims(T)+1} - -dimdec{T}(::Type{<:AbstractArray{T,1}}) = T -dimdec(T::Type{<:AbstractArray}) = deparam(T){eltype(T),ndims(T)-1} - -btype(B::Type, S::Type{<:AbstractArray}) = B -btype(B::Type{<:Batchable}, S::Type{<:AbstractArray}) = B{dimdec(S),S} -btype{T}(B::Type{<:Batchable{T}}, S::Type{<:AbstractArray}) = B{S} -btype{T,S<:AbstractArray}(B::Type{<:Batchable{T,S}}, ::Type{S}) = B -btype(B::Type{<:Batchable{<:Batchable}}, S::Type{<:AbstractArray}) = - deparam(B){btype(eltype(B), dimdec(S)),S} - -convert{T<:Batchable}(::Type{Storage{T}}, data::AbstractArray) = - Storage{btype(T,dimdec(typeof(data))),typeof(data)}(data) - -convert{T,S<:AbstractArray}(B::Type{<:Batchable{T,S}}, data::S) = B(convert(Storage{T,S}, data)) - -convert{B<:Batchable}(::Type{B}, data::AbstractArray) = convert(btype(B,typeof(data)), data) - -convert{B<:Batchable}(::Type{B}, xs::B) = xs diff --git a/src/Batches/iter.jl b/src/Batches/iter.jl deleted file mode 100644 index 885c9318..00000000 --- a/src/Batches/iter.jl +++ /dev/null @@ -1,85 +0,0 @@ -# Simple version - -using Base.Iterators: partition - -partitionr(xs, n) = take(partition(xs, n), length(xs)÷n) - -chunk(xs, n) = (partitionr(xs, length(xs)÷n)...,) - -batches(xs, n) = (Batch([xs...]) for xs in partitionr(xs, n)) -seqs(xs, n) = (Seq([xs...]) for xs in partitionr(xs, n)) - -# Stateful iteration - -mutable struct StatefulIter{I,S,T} - iter::I - state::S - next::Nullable{T} -end - -function StatefulIter(itr) - state = start(itr) - val, state = done(itr, state) ? (Nullable(), state) : next(itr, state) - return StatefulIter(itr, state, convert(Nullable, val)) -end - -peek(s::StatefulIter) = get(s.next) - -function Base.take!(s::StatefulIter) - x = peek(s) - if !done(s.iter, s.state) - s.next, s.state = next(s.iter, s.state) - else - s.next = Nullable() - end - return x -end - -Base.isempty(s::StatefulIter) = isnull(s.next) -Base.eltype(s::StatefulIter) = eltype(s.next) - -function taken!(s::StatefulIter, n::Integer) - xs = eltype(s)[] - for _ = 1:n - isempty(s) && break - push!(xs, take!(s)) - end - return xs -end - -# Batched - -import Base: start, next, done, iteratorsize, iteratoreltype, eltype, length - -struct Batched{I<:StatefulIter,S} - itr::I - buf::S -end - -function Batched(itr, n::Integer) - n >= 1 || throw(ArgumentError("batch size must be >= 1")) - itr = StatefulIter(itr) - x = peek(itr) - buf = convert(Batch{typeof(peek(itr))}, - similar(rawbatch(x), n, size(rawbatch(x))...)) - Batched(itr, buf) -end - -iteratoreltype(::Type{<:Batched}) = Base.HasEltype() -iteratorsize(::Type{<:Batched}) = Base.SizeUnknown() - -eltype{T,S}(x::Batched{T,S}) = S - -start(::Batched) = () - -next(x::Batched, _) = x.buf, () - -# will be less hacky if https://github.com/JuliaLang/julia/issues/18823 -function done(x::Batched, _) - next = taken!(x.itr, length(x.buf)) - length(next) < length(x.buf) && return true - for (i, n) in enumerate(next) - x.buf[i] = rawbatch(n) - end - return false -end diff --git a/src/Flux.jl b/src/Flux.jl index 794f3e4e..c42188f5 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -16,9 +16,6 @@ export @net, unroll, unroll1, @shapes, # Zero Flux Given -include("Batches/Batches.jl") -using .Batches - include("core.jl") import .FluxCore: back!, update!, graph @@ -37,6 +34,5 @@ include("layers/cost.jl") include("layers/recurrent.jl") include("data.jl") -include("training.jl") end # module diff --git a/src/compiler/loops.jl b/src/compiler/loops.jl index 0df64225..f31bc113 100644 --- a/src/compiler/loops.jl +++ b/src/compiler/loops.jl @@ -43,8 +43,6 @@ seqtuple(xs::AbstractArray, n) = n ≠ 0 && size(xs, 2) ≠ n ? error("Expecting sequence length $n, got $(size(xs, 2))") : (unstack(xs, 2)...) -seqtuple(xs::Batch{<:Seq}, n) = seqtuple(rawbatch(xs), n) - reseq(x) = x reseq(x::Tuple{}) = () reseq(xs::Tuple) = all(isa.(xs, AbstractArray) .& (ndims.(xs) .≥ 2)) ? stack(xs, 2) : reseq.(xs) diff --git a/src/training.jl b/src/training.jl deleted file mode 100644 index 780303d8..00000000 --- a/src/training.jl +++ /dev/null @@ -1,59 +0,0 @@ -using Juno: info -using .Batches: tobatch - -""" -Returns a function that when invoked, will only be triggered at most once -during `timeout` seconds. Normally, the throttled function will run -as much as it can, without ever going more than once per `wait` duration; -but if you'd like to disable the execution on the leading edge, pass -`leading=false`. To enable execution on the trailing edge, ditto. -""" -function throttle(f, timeout; leading=true, trailing=false) - cooldown = true - later = nothing - - function throttled(args...; kwargs...) - yield() - - if cooldown - if leading - f(args...; kwargs...) - else - later = () -> f(args...; kwargs...) - end - - cooldown = false - @schedule try - while (sleep(timeout); later != nothing) - later() - later = nothing - end - finally - cooldown = true - end - elseif trailing - later = () -> f(args...; kwargs...) - end - - nothing - end -end - -function train!(m, train; cb = [], - epoch = 1, η = 0.1, loss = mse) - callback = throttle(()->foreach(f -> f(), cb), 5) - - @progress for e in 1:epoch - info("Epoch $e") - for (x, y) in train - x, y = mapt(tobatch, (x, y)) - ŷ = m(x) - any(isnan, ŷ) && error("NaN") - Δ = back!(loss, 1, ŷ, y) - back!(m, Δ, x) - update!(m, η) - callback() - end - end - return m -end diff --git a/src/utils.jl b/src/utils.jl index 885b2ccf..c77202e1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -48,3 +48,41 @@ function accuracy(m, data) end return correct/n end + +""" +Returns a function that when invoked, will only be triggered at most once +during `timeout` seconds. Normally, the throttled function will run +as much as it can, without ever going more than once per `wait` duration; +but if you'd like to disable the execution on the leading edge, pass +`leading=false`. To enable execution on the trailing edge, ditto. +""" +function throttle(f, timeout; leading=true, trailing=false) + cooldown = true + later = nothing + + function throttled(args...; kwargs...) + yield() + + if cooldown + if leading + f(args...; kwargs...) + else + later = () -> f(args...; kwargs...) + end + + cooldown = false + @schedule try + while (sleep(timeout); later != nothing) + later() + later = nothing + end + finally + cooldown = true + end + elseif trailing + later = () -> f(args...; kwargs...) + end + + nothing + end +end diff --git a/test/backend/common.jl b/test/backend/common.jl index 7a7aebaa..23f5ed46 100644 --- a/test/backend/common.jl +++ b/test/backend/common.jl @@ -18,7 +18,7 @@ end function test_recurrence(bk) @testset "Recurrence" begin - seq = batchone(Seq(rand(10) for i = 1:3)) + seq = unsqueeze(stack(rand(10) for i = 1:3)) r = unroll(Recurrent(10, 5), 3) rm = bk(r) @test r(seq) ≈ rm(seq) diff --git a/test/batching.jl b/test/batching.jl deleted file mode 100644 index 222c1b79..00000000 --- a/test/batching.jl +++ /dev/null @@ -1,17 +0,0 @@ -using Flux.Batches, Base.Test - -@testset "Batching" begin - -bs = Batch([[1,2,3],[4,5,6]]) - -@test bs == [[1,2,3],[4,5,6]] - -@test rawbatch(bs) == [1 2 3; 4 5 6] - -batchseq = Batch([Seq([[1,2,3],[4,5,6]]),Seq([[7,8,9],[10,11,12]])]) - -@test batchseq == [[[1,2,3],[4,5,6]],[[7,8,9],[10,11,12]]] -@test rawbatch(batchseq)[1,1,3] == 3 -@test rawbatch(batchseq)[2,2,1] == 10 - -end diff --git a/test/optimizer.jl b/test/optimizer.jl deleted file mode 100644 index 57f1d011..00000000 --- a/test/optimizer.jl +++ /dev/null @@ -1,38 +0,0 @@ -@testset "training julia models" begin - - @testset "linear regression" begin - srand(0) - - model = Affine(10, 1) - - truth = Float32[0, 4, 2, 2, -3, 6, -1, 3, 2, -5]' - - data = map(1:256) do i - x = rand(Float32, 10) - x, truth * x + 3rand(Float32) - end - - Flux.train!(model, data, epoch=5) - - @test cor(reshape.((model.W.x, truth), 10)...) > .99 - end - - @testset "logistic regression" begin - srand(0) - - model = Chain(Affine(10, 1), σ) - - truth = Float32[0, 4, 2, 2, -3, 6, -1, 3, 2, -5]' - - data = map(1:256) do i - x = rand(Float32, 10) - x, truth * x + 2rand(Float32) > 5f0 - end - - Flux.train!(model, data, epoch=10) - - @test cor(reshape.((model.layers[1].W.x, truth), 10)...) > .99 - end - -end - diff --git a/test/recurrent.jl b/test/recurrent.jl index 3f04d6c4..236a5b59 100644 --- a/test/recurrent.jl +++ b/test/recurrent.jl @@ -13,5 +13,5 @@ end _, ys = apply(unroll1(r).model, xs, (r.y.x,)) @test ys[1] == tanh(xs[1] * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x) ru = unroll(r, 3) - ru(batchone(Seq(squeeze.(xs))))[1] == squeeze.(ys) + ru(unsqueeze(stack(squeeze.(xs))))[1] == squeeze.(ys) end diff --git a/test/runtests.jl b/test/runtests.jl index 08e1ea9a..b101fa12 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,15 +1,13 @@ using Flux, DataFlow, MacroTools, Base.Test -using Flux: graph, Param, squeeze, unsqueeze, back!, update!, flatten +using Flux: graph, Param, squeeze, unsqueeze, stack, back!, update!, flatten using DataFlow: Line, Frame @testset "Flux" begin -include("batching.jl") include("backend/common.jl") include("basic.jl") include("recurrent.jl") -include("optimizer.jl") include("throttle.jl") end