remove batching and training

This commit is contained in:
Mike J Innes 2017-08-18 01:04:50 +01:00
parent 5f9d8702a4
commit e79a1657d4
13 changed files with 41 additions and 362 deletions

View File

@ -1,12 +0,0 @@
module Batches
using Juno, Lazy
using Juno: Tree, Row
export Batch, Batched, Seq, rawbatch, batchone
include("catmat.jl")
include("batch.jl")
include("iter.jl")
end

View File

@ -1,40 +0,0 @@
# Batches
struct Batch{T,S} <: Batchable{T,S}
data::Storage{T,S}
Batch{T,S}(data::Storage{T,S}) where {T,S} = new{T,S}(data)
end
Batch(data::Storage{T,S}) where {T,S} = Batch{T,S}(data)
Batch(xs) = Batch(Storage(xs))
Batch{T,S}(xs) where {T,S} = Batch{T,S}(Storage{T,S}(xs))
storage(b::Batch) = b.data
convertel(T::Type, xs::Batch) =
eltype(eltype(xs)) isa T ? xs :
Batch(map(x->convertel(T, x), xs))
batchone(x) = Batch((x,))
batchone(x::Batch) = x
tobatch(xs::Batch) = rawbatch(xs)
tobatch(xs) = tobatch(batchone(xs))
# Sequences
struct Seq{T,S} <: Batchable{T,S}
data::Storage{T,S}
Seq{T,S}(data::Storage{T,S}) where {T,S} = new{T,S}(data)
end
Seq(data::Storage{T,S}) where {T,S} = Seq{T,S}(data)
Seq(xs) = Seq(Storage(xs))
Seq{T,S}(xs) where {T,S} = Seq{T,S}(Storage{T,S}(xs))
storage(s::Seq) = s.data
Base.rpad{T}(xs::Seq{T}, n::Integer, x::T) =
n-length(xs) 0 ? xs : vcat(xs, typeof(xs)(repeated(x, n-length(xs))))

View File

@ -1,100 +0,0 @@
import Base: eltype, size, getindex, setindex!, convert, typename
# Concrete storage
struct Storage{T,S}
data::S
Storage{T,S}(data::S) where {T,S} = new{T,S}(data)
end
allequal(xs) = all(x -> x == first(xs), xs)
function Storage{T,S}(xs, storage::S) where {T, S}
@assert allequal(map(size, xs))
@assert size(storage) == (length(xs), size(first(xs))...)
for i = 1:length(xs)
storage[i, :] = xs[i]
end
return Storage{T,S}(storage)
end
function Storage{T,S}(xs) where {T,S}
xs = map(rawbatch, xs)
storage = S(length(xs), size(first(xs))...)
Storage{T,typeof(storage)}(xs, storage)
end
Storage{T}(xs) where T = Storage{T,diminc(typeof(rawbatch(first(xs))))}(xs)
Storage(xs) = Storage{eltype(xs)}(xs)
convert{T,S}(::Type{Storage{T,S}}, data::S) = Storage{T,S}(data)
convert{T}(::Type{Storage{T}}, data::AbstractArray) = convert(Storage{T,typeof(data)}, data)
# Storage utility methods
rawbatch(xs) = xs
rawbatch(xs::Storage) = xs.data
eltype{T}(::Storage{T}) = T
size(b::Storage) = (size(b.data, 1),)
getindex(b::Storage, i)::eltype(b) = slicedim(b.data, 1, i)
setindex!(b::Storage, v, i::Integer) = b.data[i, :] = v
function setindex!(b::Storage, xs, ::Colon)
for (i, x) in enumerate(xs)
b[i] = x
end
end
# Generic methods
abstract type Batchable{T,S} <: AbstractVector{T}
end
rawbatch(xs::Batchable) = rawbatch(storage(xs))
size(xs::Batchable) = size(storage(xs))
getindex(xs::Batchable, i) = getindex(storage(xs), i)
setindex!(xs::Batchable, v, i...) = setindex!(storage(xs), v, i...)
Base.vcat{T<:Batchable}(xs::T, ys::T)::T = vcat(rawbatch(xs), rawbatch(ys))
typerender(B::Type) = B
typerender(B::Type{<:Batchable}) =
Row(Juno.typ("$(typename(B).name)"), text"{", typerender(eltype(B)), text"}")
@render Juno.Inline b::Batchable begin
Tree(Row(typerender(typeof(b)),
Juno.fade("[$(length(b))]")),
Juno.trim(collect(b)))
end
# Horrible type hacks follow this point
deparam(T::Type) = typename(T).wrapper
diminc(T::Type) = Vector{T}
diminc(T::Type{<:AbstractArray}) = deparam(T){eltype(T),ndims(T)+1}
dimdec{T}(::Type{<:AbstractArray{T,1}}) = T
dimdec(T::Type{<:AbstractArray}) = deparam(T){eltype(T),ndims(T)-1}
btype(B::Type, S::Type{<:AbstractArray}) = B
btype(B::Type{<:Batchable}, S::Type{<:AbstractArray}) = B{dimdec(S),S}
btype{T}(B::Type{<:Batchable{T}}, S::Type{<:AbstractArray}) = B{S}
btype{T,S<:AbstractArray}(B::Type{<:Batchable{T,S}}, ::Type{S}) = B
btype(B::Type{<:Batchable{<:Batchable}}, S::Type{<:AbstractArray}) =
deparam(B){btype(eltype(B), dimdec(S)),S}
convert{T<:Batchable}(::Type{Storage{T}}, data::AbstractArray) =
Storage{btype(T,dimdec(typeof(data))),typeof(data)}(data)
convert{T,S<:AbstractArray}(B::Type{<:Batchable{T,S}}, data::S) = B(convert(Storage{T,S}, data))
convert{B<:Batchable}(::Type{B}, data::AbstractArray) = convert(btype(B,typeof(data)), data)
convert{B<:Batchable}(::Type{B}, xs::B) = xs

View File

@ -1,85 +0,0 @@
# Simple version
using Base.Iterators: partition
partitionr(xs, n) = take(partition(xs, n), length(xs)÷n)
chunk(xs, n) = (partitionr(xs, length(xs)÷n)...,)
batches(xs, n) = (Batch([xs...]) for xs in partitionr(xs, n))
seqs(xs, n) = (Seq([xs...]) for xs in partitionr(xs, n))
# Stateful iteration
mutable struct StatefulIter{I,S,T}
iter::I
state::S
next::Nullable{T}
end
function StatefulIter(itr)
state = start(itr)
val, state = done(itr, state) ? (Nullable(), state) : next(itr, state)
return StatefulIter(itr, state, convert(Nullable, val))
end
peek(s::StatefulIter) = get(s.next)
function Base.take!(s::StatefulIter)
x = peek(s)
if !done(s.iter, s.state)
s.next, s.state = next(s.iter, s.state)
else
s.next = Nullable()
end
return x
end
Base.isempty(s::StatefulIter) = isnull(s.next)
Base.eltype(s::StatefulIter) = eltype(s.next)
function taken!(s::StatefulIter, n::Integer)
xs = eltype(s)[]
for _ = 1:n
isempty(s) && break
push!(xs, take!(s))
end
return xs
end
# Batched
import Base: start, next, done, iteratorsize, iteratoreltype, eltype, length
struct Batched{I<:StatefulIter,S}
itr::I
buf::S
end
function Batched(itr, n::Integer)
n >= 1 || throw(ArgumentError("batch size must be >= 1"))
itr = StatefulIter(itr)
x = peek(itr)
buf = convert(Batch{typeof(peek(itr))},
similar(rawbatch(x), n, size(rawbatch(x))...))
Batched(itr, buf)
end
iteratoreltype(::Type{<:Batched}) = Base.HasEltype()
iteratorsize(::Type{<:Batched}) = Base.SizeUnknown()
eltype{T,S}(x::Batched{T,S}) = S
start(::Batched) = ()
next(x::Batched, _) = x.buf, ()
# will be less hacky if https://github.com/JuliaLang/julia/issues/18823
function done(x::Batched, _)
next = taken!(x.itr, length(x.buf))
length(next) < length(x.buf) && return true
for (i, n) in enumerate(next)
x.buf[i] = rawbatch(n)
end
return false
end

View File

@ -16,9 +16,6 @@ export @net, unroll, unroll1, @shapes,
# Zero Flux Given
include("Batches/Batches.jl")
using .Batches
include("core.jl")
import .FluxCore: back!, update!, graph
@ -37,6 +34,5 @@ include("layers/cost.jl")
include("layers/recurrent.jl")
include("data.jl")
include("training.jl")
end # module

View File

@ -43,8 +43,6 @@ seqtuple(xs::AbstractArray, n) =
n 0 && size(xs, 2) n ? error("Expecting sequence length $n, got $(size(xs, 2))") :
(unstack(xs, 2)...)
seqtuple(xs::Batch{<:Seq}, n) = seqtuple(rawbatch(xs), n)
reseq(x) = x
reseq(x::Tuple{}) = ()
reseq(xs::Tuple) = all(isa.(xs, AbstractArray) .& (ndims.(xs) .≥ 2)) ? stack(xs, 2) : reseq.(xs)

View File

@ -1,59 +0,0 @@
using Juno: info
using .Batches: tobatch
"""
Returns a function that when invoked, will only be triggered at most once
during `timeout` seconds. Normally, the throttled function will run
as much as it can, without ever going more than once per `wait` duration;
but if you'd like to disable the execution on the leading edge, pass
`leading=false`. To enable execution on the trailing edge, ditto.
"""
function throttle(f, timeout; leading=true, trailing=false)
cooldown = true
later = nothing
function throttled(args...; kwargs...)
yield()
if cooldown
if leading
f(args...; kwargs...)
else
later = () -> f(args...; kwargs...)
end
cooldown = false
@schedule try
while (sleep(timeout); later != nothing)
later()
later = nothing
end
finally
cooldown = true
end
elseif trailing
later = () -> f(args...; kwargs...)
end
nothing
end
end
function train!(m, train; cb = [],
epoch = 1, η = 0.1, loss = mse)
callback = throttle(()->foreach(f -> f(), cb), 5)
@progress for e in 1:epoch
info("Epoch $e")
for (x, y) in train
x, y = mapt(tobatch, (x, y))
= m(x)
any(isnan, ) && error("NaN")
Δ = back!(loss, 1, , y)
back!(m, Δ, x)
update!(m, η)
callback()
end
end
return m
end

View File

@ -48,3 +48,41 @@ function accuracy(m, data)
end
return correct/n
end
"""
Returns a function that when invoked, will only be triggered at most once
during `timeout` seconds. Normally, the throttled function will run
as much as it can, without ever going more than once per `wait` duration;
but if you'd like to disable the execution on the leading edge, pass
`leading=false`. To enable execution on the trailing edge, ditto.
"""
function throttle(f, timeout; leading=true, trailing=false)
cooldown = true
later = nothing
function throttled(args...; kwargs...)
yield()
if cooldown
if leading
f(args...; kwargs...)
else
later = () -> f(args...; kwargs...)
end
cooldown = false
@schedule try
while (sleep(timeout); later != nothing)
later()
later = nothing
end
finally
cooldown = true
end
elseif trailing
later = () -> f(args...; kwargs...)
end
nothing
end
end

View File

@ -18,7 +18,7 @@ end
function test_recurrence(bk)
@testset "Recurrence" begin
seq = batchone(Seq(rand(10) for i = 1:3))
seq = unsqueeze(stack(rand(10) for i = 1:3))
r = unroll(Recurrent(10, 5), 3)
rm = bk(r)
@test r(seq) rm(seq)

View File

@ -1,17 +0,0 @@
using Flux.Batches, Base.Test
@testset "Batching" begin
bs = Batch([[1,2,3],[4,5,6]])
@test bs == [[1,2,3],[4,5,6]]
@test rawbatch(bs) == [1 2 3; 4 5 6]
batchseq = Batch([Seq([[1,2,3],[4,5,6]]),Seq([[7,8,9],[10,11,12]])])
@test batchseq == [[[1,2,3],[4,5,6]],[[7,8,9],[10,11,12]]]
@test rawbatch(batchseq)[1,1,3] == 3
@test rawbatch(batchseq)[2,2,1] == 10
end

View File

@ -1,38 +0,0 @@
@testset "training julia models" begin
@testset "linear regression" begin
srand(0)
model = Affine(10, 1)
truth = Float32[0, 4, 2, 2, -3, 6, -1, 3, 2, -5]'
data = map(1:256) do i
x = rand(Float32, 10)
x, truth * x + 3rand(Float32)
end
Flux.train!(model, data, epoch=5)
@test cor(reshape.((model.W.x, truth), 10)...) > .99
end
@testset "logistic regression" begin
srand(0)
model = Chain(Affine(10, 1), σ)
truth = Float32[0, 4, 2, 2, -3, 6, -1, 3, 2, -5]'
data = map(1:256) do i
x = rand(Float32, 10)
x, truth * x + 2rand(Float32) > 5f0
end
Flux.train!(model, data, epoch=10)
@test cor(reshape.((model.layers[1].W.x, truth), 10)...) > .99
end
end

View File

@ -13,5 +13,5 @@ end
_, ys = apply(unroll1(r).model, xs, (r.y.x,))
@test ys[1] == tanh(xs[1] * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x)
ru = unroll(r, 3)
ru(batchone(Seq(squeeze.(xs))))[1] == squeeze.(ys)
ru(unsqueeze(stack(squeeze.(xs))))[1] == squeeze.(ys)
end

View File

@ -1,15 +1,13 @@
using Flux, DataFlow, MacroTools, Base.Test
using Flux: graph, Param, squeeze, unsqueeze, back!, update!, flatten
using Flux: graph, Param, squeeze, unsqueeze, stack, back!, update!, flatten
using DataFlow: Line, Frame
@testset "Flux" begin
include("batching.jl")
include("backend/common.jl")
include("basic.jl")
include("recurrent.jl")
include("optimizer.jl")
include("throttle.jl")
end