remove batching and training
This commit is contained in:
parent
5f9d8702a4
commit
e79a1657d4
@ -1,12 +0,0 @@
|
||||
module Batches
|
||||
|
||||
using Juno, Lazy
|
||||
using Juno: Tree, Row
|
||||
|
||||
export Batch, Batched, Seq, rawbatch, batchone
|
||||
|
||||
include("catmat.jl")
|
||||
include("batch.jl")
|
||||
include("iter.jl")
|
||||
|
||||
end
|
@ -1,40 +0,0 @@
|
||||
# Batches
|
||||
|
||||
"""
    Batch{T,S}

A `Batchable{T,S}` whose contents are held in a single `Storage{T,S}` field.
"""
struct Batch{T,S} <: Batchable{T,S}
  data::Storage{T,S}

  function Batch{T,S}(data::Storage{T,S}) where {T,S}
    return new{T,S}(data)
  end
end
|
||||
|
||||
# Outer constructors: wrap an existing `Storage`, or build one from `xs` first.
function Batch(data::Storage{T,S}) where {T,S}
  return Batch{T,S}(data)
end

function Batch(xs)
  return Batch(Storage(xs))
end

function Batch{T,S}(xs) where {T,S}
  return Batch{T,S}(Storage{T,S}(xs))
end

# Accessor for the underlying `Storage` of a batch.
function storage(b::Batch)
  return b.data
end
|
||||
|
||||
"""
    convertel(T::Type, xs::Batch)

Return `xs` unchanged when the element type of its elements is already a
subtype of `T`; otherwise build a new `Batch` by applying `convertel(T, x)` to
every element `x` of `xs`.
"""
convertel(T::Type, xs::Batch) =
  # `eltype(eltype(xs))` is itself a type, so the check must be the subtype
  # relation (`<:`), not an instance test (`isa`).
  eltype(eltype(xs)) <: T ? xs :
    Batch(map(x -> convertel(T, x), xs))
|
||||
|
||||
# `batchone` wraps a single value in a one-element `Batch`; an existing `Batch`
# is passed through untouched.
function batchone(x::Batch)
  return x
end

function batchone(x)
  return Batch((x,))
end

# `tobatch` returns the raw batched data; anything that is not a `Batch` is
# first wrapped with `batchone`.
function tobatch(xs::Batch)
  return rawbatch(xs)
end

function tobatch(xs)
  wrapped = batchone(xs)
  return tobatch(wrapped)
end
|
||||
|
||||
# Sequences
|
||||
|
||||
"""
    Seq{T,S}

A `Batchable{T,S}` whose contents are held in a single `Storage{T,S}` field.
"""
struct Seq{T,S} <: Batchable{T,S}
  data::Storage{T,S}

  function Seq{T,S}(data::Storage{T,S}) where {T,S}
    return new{T,S}(data)
  end
end
|
||||
|
||||
# Outer constructors: wrap an existing `Storage`, or build one from `xs` first.
function Seq(data::Storage{T,S}) where {T,S}
  return Seq{T,S}(data)
end

function Seq(xs)
  return Seq(Storage(xs))
end

function Seq{T,S}(xs) where {T,S}
  return Seq{T,S}(Storage{T,S}(xs))
end

# Accessor for the underlying `Storage` of a sequence.
function storage(s::Seq)
  return s.data
end
|
||||
|
||||
"""
    rpad(xs::Seq{T}, n::Integer, x::T)

Return `xs` extended to length `n` by appending copies of `x`. When `xs`
already has at least `n` elements it is returned as is.
"""
function Base.rpad(xs::Seq{T}, n::Integer, x::T) where T
  deficit = n - length(xs)
  deficit ≤ 0 && return xs
  # Build the padding as the same sequence type so `vcat` sees matching types.
  return vcat(xs, typeof(xs)(repeated(x, deficit)))
end
|
@ -1,100 +0,0 @@
|
||||
import Base: eltype, size, getindex, setindex!, convert, typename
|
||||
|
||||
# Concrete storage
|
||||
|
||||
"""
    Storage{T,S}

Wrapper around a backing container `data` of type `S`. The parameter `T` is
carried alongside it and is what `eltype` reports for the storage.
"""
struct Storage{T,S}
  data::S

  function Storage{T,S}(data::S) where {T,S}
    return new{T,S}(data)
  end
end
|
||||
|
||||
# True when every element of `xs` compares `==` to the first one
# (vacuously true for an empty collection).
function allequal(xs)
  matchesfirst(x) = x == first(xs)
  return all(matchesfirst, xs)
end
|
||||
|
||||
# Fill the preallocated `storage` with the elements of `xs`, one per leading
# index, and wrap the result. All elements must share one size, and `storage`
# must have size `(length(xs), size(first(xs))...)`.
function Storage{T,S}(xs, storage::S) where {T, S}
  @assert allequal(map(size, xs))
  @assert size(storage) == (length(xs), size(first(xs))...)
  for (i, x) in enumerate(xs)
    storage[i, :] = x
  end
  return Storage{T,S}(storage)
end
|
||||
|
||||
# Build storage from a collection: unwrap every element with `rawbatch`,
# allocate an `S` with one extra leading dimension, then copy the elements in.
function Storage{T,S}(xs) where {T,S}
  raw = map(rawbatch, xs)
  dims = (length(raw), size(first(raw))...)
  buffer = S(dims...)
  return Storage{T,typeof(buffer)}(raw, buffer)
end
|
||||
|
||||
# Pick the backing container type from the first element's raw type, raised by
# one dimension via `diminc`.
function Storage{T}(xs) where T
  S = diminc(typeof(rawbatch(first(xs))))
  return Storage{T,S}(xs)
end

# Default the element type parameter to `eltype(xs)`.
function Storage(xs)
  return Storage{eltype(xs)}(xs)
end
|
||||
|
||||
# Wrap data that already has the backing type `S` directly.
convert(::Type{Storage{T,S}}, data::S) where {T,S} = Storage{T,S}(data)

# With only the element type given, take the backing type from the array.
convert(::Type{Storage{T}}, data::AbstractArray) where T =
  convert(Storage{T,typeof(data)}, data)
|
||||
|
||||
# Storage utility methods
|
||||
|
||||
# `rawbatch` unwraps a `Storage` to its backing data; any other value is
# returned as is.
function rawbatch(xs::Storage)
  return xs.data
end

function rawbatch(xs)
  return xs
end
|
||||
|
||||
# The element type is the first type parameter, not the eltype of the backing data.
eltype(::Storage{T}) where T = T

# A storage is indexed along the first dimension of its backing data only.
size(b::Storage) = (size(b.data, 1),)

# Element `i` is the slice of the backing data at index `i` of dimension 1,
# converted to the storage's element type.
getindex(b::Storage, i)::eltype(b) = slicedim(b.data, 1, i)

# Write `v` into the slice at leading index `i`.
setindex!(b::Storage, v, i::Integer) = b.data[i, :] = v
|
||||
|
||||
# Assign the elements of `xs` to consecutive positions of `b`, starting at 1.
function setindex!(b::Storage, xs, ::Colon)
  i = 0
  for x in xs
    i += 1
    b[i] = x
  end
end
|
||||
|
||||
# Generic methods
|
||||
|
||||
# Common supertype of the storage-backed vector types. Subtypes provide
# `storage(xs)`; the methods below forward to that storage.
abstract type Batchable{T,S} <: AbstractVector{T} end

function rawbatch(xs::Batchable)
  return rawbatch(storage(xs))
end

function size(xs::Batchable)
  return size(storage(xs))
end

function getindex(xs::Batchable, i)
  return getindex(storage(xs), i)
end

function setindex!(xs::Batchable, v, i...)
  return setindex!(storage(xs), v, i...)
end
|
||||
|
||||
# Concatenate two batchables of the same type by concatenating their raw data;
# the result is converted back to that type `T`.
function Base.vcat(xs::T, ys::T) where {T<:Batchable}
  return convert(T, vcat(rawbatch(xs), rawbatch(ys)))::T
end
|
||||
|
||||
# Display form of a type: `Batchable` types render as `Name{<eltype>}`,
# recursing into the element type; any other type is shown unchanged.
typerender(B::Type) = B

function typerender(B::Type{<:Batchable})
  name = Juno.typ("$(typename(B).name)")
  return Row(name, text"{", typerender(eltype(B)), text"}")
end
|
||||
|
||||
# Inline Juno rendering: a header row with the rendered type and the length,
# followed by the collected elements.
@render Juno.Inline b::Batchable begin
  header = Row(typerender(typeof(b)), Juno.fade("[$(length(b))]"))
  Tree(header, Juno.trim(collect(b)))
end
|
||||
|
||||
# Horrible type hacks follow this point
|
||||
|
||||
# Strip the type parameters: `deparam(Array{Int,2})` is the `Array` wrapper.
deparam(T::Type) = typename(T).wrapper

# Raise the dimensionality by one: a non-array type becomes a vector of it, an
# array type gains one dimension.
diminc(T::Type) = Vector{T}
diminc(T::Type{<:AbstractArray}) = deparam(T){eltype(T),ndims(T)+1}

# Lower the dimensionality by one: a 1-d array type becomes its element type,
# any other array type loses one dimension.
dimdec(::Type{<:AbstractArray{T,1}}) where T = T
dimdec(T::Type{<:AbstractArray}) = deparam(T){eltype(T),ndims(T)-1}
|
||||
|
||||
# `btype(B, S)` fills in whichever parameters of the batchable type `B` are
# still missing, given the backing array type `S`.
btype(B::Type, S::Type{<:AbstractArray}) = B
btype(B::Type{<:Batchable}, S::Type{<:AbstractArray}) = B{dimdec(S),S}
btype(B::Type{<:Batchable{T}}, S::Type{<:AbstractArray}) where T = B{S}
btype(B::Type{<:Batchable{T,S}}, ::Type{S}) where {T,S<:AbstractArray} = B
# Nested batchables: resolve the element type against `S` lowered by one dimension.
btype(B::Type{<:Batchable{<:Batchable}}, S::Type{<:AbstractArray}) =
  deparam(B){btype(eltype(B), dimdec(S)),S}
|
||||
|
||||
# Storage of batchable elements: resolve the element type against the array
# type lowered by one dimension.
convert(::Type{Storage{T}}, data::AbstractArray) where {T<:Batchable} =
  Storage{btype(T,dimdec(typeof(data))),typeof(data)}(data)

# Fully parameterised batchable type: wrap the data in matching storage.
convert(B::Type{<:Batchable{T,S}}, data::S) where {T,S<:AbstractArray} =
  B(convert(Storage{T,S}, data))

# Partially parameterised batchable type: complete it with `btype`, then retry.
convert(::Type{B}, data::AbstractArray) where {B<:Batchable} =
  convert(btype(B,typeof(data)), data)

# Already the requested type.
convert(::Type{B}, xs::B) where {B<:Batchable} = xs
|
@ -1,85 +0,0 @@
|
||||
# Simple version
|
||||
|
||||
using Base.Iterators: partition
|
||||
|
||||
# Partition `xs` into groups of `n`, keeping only the `length(xs) ÷ n` full groups.
function partitionr(xs, n)
  parts = partition(xs, n)
  return take(parts, length(xs) ÷ n)
end

# Tuple of the full groups of size `length(xs) ÷ n`.
function chunk(xs, n)
  return tuple(partitionr(xs, length(xs) ÷ n)...)
end

# Lazily wrap each full group of `n` elements in a `Batch` / `Seq`.
batches(xs, n) = (Batch([part...]) for part in partitionr(xs, n))

seqs(xs, n) = (Seq([part...]) for part in partitionr(xs, n))
|
||||
|
||||
# Stateful iteration
|
||||
|
||||
# Iterator wrapper that keeps its iteration state plus a one-element lookahead
# (`next`, null once the underlying iterator is exhausted).
mutable struct StatefulIter{I,S,T}
  iter::I
  state::S
  next::Nullable{T}
end

# Start iterating `itr` and prefetch its first element, if there is one.
function StatefulIter(itr)
  st = start(itr)
  if done(itr, st)
    val = Nullable()
  else
    val, st = next(itr, st)
  end
  return StatefulIter(itr, st, convert(Nullable, val))
end
|
||||
|
||||
# Look at the prefetched element without consuming it.
function peek(s::StatefulIter)
  return get(s.next)
end

# Return the prefetched element and advance: prefetch the following element,
# or mark the iterator as exhausted.
function Base.take!(s::StatefulIter)
  x = peek(s)
  if done(s.iter, s.state)
    s.next = Nullable()
    return x
  end
  s.next, s.state = next(s.iter, s.state)
  return x
end

# Exhausted once no element is prefetched.
function Base.isempty(s::StatefulIter)
  return isnull(s.next)
end

function Base.eltype(s::StatefulIter)
  return eltype(s.next)
end
|
||||
|
||||
# Consume up to `n` elements from `s` into a vector; fewer are returned when
# the iterator runs out first.
function taken!(s::StatefulIter, n::Integer)
  out = eltype(s)[]
  k = 1
  while k <= n && !isempty(s)
    push!(out, take!(s))
    k += 1
  end
  return out
end
|
||||
|
||||
# Batched
|
||||
|
||||
import Base: start, next, done, iteratorsize, iteratoreltype, eltype, length
|
||||
|
||||
# Iterator that yields batches of elements from a `StatefulIter`, reusing the
# single buffer `buf` for every batch.
struct Batched{I<:StatefulIter,S}
  itr::I
  buf::S
end

# Wrap `itr` and allocate a buffer for `n` elements shaped like the first
# element's raw data.
function Batched(itr, n::Integer)
  n >= 1 || throw(ArgumentError("batch size must be >= 1"))
  source = StatefulIter(itr)
  head = peek(source)
  raw = rawbatch(head)
  buf = convert(Batch{typeof(head)}, similar(raw, n, size(raw)...))
  return Batched(source, buf)
end
|
||||
|
||||
# Iterator traits: the element type is known, the number of batches is not.
iteratoreltype(::Type{<:Batched}) = Base.HasEltype()
iteratorsize(::Type{<:Batched}) = Base.SizeUnknown()

# Each iteration yields the buffer, so the element type is the buffer type.
eltype(x::Batched{I,S}) where {I,S} = S
|
||||
|
||||
# The iteration state is unused; all progress lives in the wrapped iterator.
start(::Batched) = ()

# The buffer has already been filled by the preceding `done` call.
next(x::Batched, _) = x.buf, ()

# will be less hacky if https://github.com/JuliaLang/julia/issues/18823
# `done` does the real work: it pulls the next batch's elements and copies
# them into the buffer, reporting `true` when a full batch is not available.
function done(x::Batched, _)
  items = taken!(x.itr, length(x.buf))
  if length(items) < length(x.buf)
    return true
  end
  for (i, item) in enumerate(items)
    x.buf[i] = rawbatch(item)
  end
  return false
end
|
@ -16,9 +16,6 @@ export @net, unroll, unroll1, @shapes,
|
||||
|
||||
# Zero Flux Given
|
||||
|
||||
include("Batches/Batches.jl")
|
||||
using .Batches
|
||||
|
||||
include("core.jl")
|
||||
import .FluxCore: back!, update!, graph
|
||||
|
||||
@ -37,6 +34,5 @@ include("layers/cost.jl")
|
||||
include("layers/recurrent.jl")
|
||||
|
||||
include("data.jl")
|
||||
include("training.jl")
|
||||
|
||||
end # module
|
||||
|
@ -43,8 +43,6 @@ seqtuple(xs::AbstractArray, n) =
|
||||
n ≠ 0 && size(xs, 2) ≠ n ? error("Expecting sequence length $n, got $(size(xs, 2))") :
|
||||
(unstack(xs, 2)...)
|
||||
|
||||
seqtuple(xs::Batch{<:Seq}, n) = seqtuple(rawbatch(xs), n)
|
||||
|
||||
# `reseq` re-stacks a tuple of arrays along dimension 2. Non-tuples and the
# empty tuple are returned unchanged; a tuple that is not made up entirely of
# arrays with at least two dimensions is processed element by element.
reseq(x) = x
reseq(x::Tuple{}) = ()
function reseq(xs::Tuple)
  # Short-circuit so `ndims` is only asked of actual arrays; elements without
  # an `ndims` method then fall through to the element-wise branch.
  isstackable(x) = isa(x, AbstractArray) && ndims(x) ≥ 2
  return all(isstackable, xs) ? stack(xs, 2) : reseq.(xs)
end
|
||||
|
@ -1,59 +0,0 @@
|
||||
using Juno: info
|
||||
using .Batches: tobatch
|
||||
|
||||
"""
    throttle(f, timeout; leading=true, trailing=false)

Return a function that, when invoked, triggers `f` at most once during
`timeout` seconds. Normally, the throttled function will run as much as it
can, without ever going more than once per `timeout` duration; but if you'd
like to disable the execution on the leading edge, pass `leading=false`. To
enable execution on the trailing edge, pass `trailing=true`.

The returned function forwards its arguments to `f` and always returns
`nothing`.
"""
function throttle(f, timeout; leading=true, trailing=false)
  cooldown = true   # true while the next call may start a new throttle window
  later = nothing   # deferred call waiting for the window to end, if any

  function throttled(args...; kwargs...)
    yield()

    if cooldown
      if leading
        f(args...; kwargs...)
      else
        later = () -> f(args...; kwargs...)
      end

      cooldown = false

      # Background task: after every `timeout` seconds run the deferred call,
      # if there is one; once a window passes with nothing deferred, reopen.
      function window()
        try
          while (sleep(timeout); later !== nothing)
            later()
            later = nothing
          end
        finally
          cooldown = true
        end
      end
      schedule(Task(window))
    elseif trailing
      later = () -> f(args...; kwargs...)
    end

    nothing
  end

  return throttled
end
|
||||
|
||||
# Train model `m` on the `(x, y)` pairs in `train` for `epoch` passes with
# step size `η`. The callbacks in `cb` are invoked through a throttle with a
# 5 second timeout. Returns `m`.
function train!(m, train; cb = [],
                epoch = 1, η = 0.1, loss = mse)
  runcallbacks() = foreach(f -> f(), cb)
  callback = throttle(runcallbacks, 5)

  @progress for e in 1:epoch
    info("Epoch $e")
    for (rawx, rawy) in train
      x, y = mapt(tobatch, (rawx, rawy))
      ŷ = m(x)
      if any(isnan, ŷ)
        error("NaN")
      end
      Δ = back!(loss, 1, ŷ, y)
      back!(m, Δ, x)
      update!(m, η)
      callback()
    end
  end
  return m
end
|
38
src/utils.jl
38
src/utils.jl
@ -48,3 +48,41 @@ function accuracy(m, data)
|
||||
end
|
||||
return correct/n
|
||||
end
|
||||
|
||||
"""
    throttle(f, timeout; leading=true, trailing=false)

Return a function that, when invoked, triggers `f` at most once during
`timeout` seconds. Normally, the throttled function will run as much as it
can, without ever going more than once per `timeout` duration; but if you'd
like to disable the execution on the leading edge, pass `leading=false`. To
enable execution on the trailing edge, pass `trailing=true`.

The returned function forwards its arguments to `f` and always returns
`nothing`.
"""
function throttle(f, timeout; leading=true, trailing=false)
  cooldown = true   # true while the next call may start a new throttle window
  later = nothing   # deferred call waiting for the window to end, if any

  function throttled(args...; kwargs...)
    yield()

    if cooldown
      if leading
        f(args...; kwargs...)
      else
        later = () -> f(args...; kwargs...)
      end

      cooldown = false

      # Background task: after every `timeout` seconds run the deferred call,
      # if there is one; once a window passes with nothing deferred, reopen.
      function window()
        try
          while (sleep(timeout); later !== nothing)
            later()
            later = nothing
          end
        finally
          cooldown = true
        end
      end
      schedule(Task(window))
    elseif trailing
      later = () -> f(args...; kwargs...)
    end

    nothing
  end

  return throttled
end
|
||||
|
@ -18,7 +18,7 @@ end
|
||||
|
||||
function test_recurrence(bk)
|
||||
@testset "Recurrence" begin
|
||||
seq = batchone(Seq(rand(10) for i = 1:3))
|
||||
seq = unsqueeze(stack(rand(10) for i = 1:3))
|
||||
r = unroll(Recurrent(10, 5), 3)
|
||||
rm = bk(r)
|
||||
@test r(seq) ≈ rm(seq)
|
||||
|
@ -1,17 +0,0 @@
|
||||
using Flux.Batches, Base.Test
|
||||
|
||||
@testset "Batching" begin

  # A batch of vectors behaves like the vector of vectors it was built from,
  # while its raw data has one row per element.
  rows = [[1,2,3],[4,5,6]]
  bs = Batch(rows)

  @test bs == [[1,2,3],[4,5,6]]
  @test rawbatch(bs) == [1 2 3; 4 5 6]

  # A batch of sequences is stored as one 3-d array indexed
  # (batch, step, feature).
  first_seq = Seq([[1,2,3],[4,5,6]])
  second_seq = Seq([[7,8,9],[10,11,12]])
  batchseq = Batch([first_seq, second_seq])
  raw = rawbatch(batchseq)

  @test batchseq == [[[1,2,3],[4,5,6]],[[7,8,9],[10,11,12]]]
  @test raw[1,1,3] == 3
  @test raw[2,2,1] == 10

end
|
@ -1,38 +0,0 @@
|
||||
@testset "training julia models" begin

  # Fit a single affine layer to noisy linear data and check that the learned
  # weights correlate with the generating weights.
  @testset "linear regression" begin
    srand(0)

    model = Affine(10, 1)
    truth = Float32[0, 4, 2, 2, -3, 6, -1, 3, 2, -5]'

    samples = map(1:256) do i
      x = rand(Float32, 10)
      x, truth * x + 3rand(Float32)
    end

    Flux.train!(model, samples, epoch=5)

    learned = model.W.x
    @test cor(reshape.((learned, truth), 10)...) > .99
  end

  # Same check with a sigmoid on top and thresholded targets.
  @testset "logistic regression" begin
    srand(0)

    model = Chain(Affine(10, 1), σ)
    truth = Float32[0, 4, 2, 2, -3, 6, -1, 3, 2, -5]'

    samples = map(1:256) do i
      x = rand(Float32, 10)
      x, truth * x + 2rand(Float32) > 5f0
    end

    Flux.train!(model, samples, epoch=10)

    learned = model.layers[1].W.x
    @test cor(reshape.((learned, truth), 10)...) > .99
  end

end
|
||||
|
@ -13,5 +13,5 @@ end
|
||||
_, ys = apply(unroll1(r).model, xs, (r.y.x,))
|
||||
@test ys[1] == tanh(xs[1] * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x)
|
||||
ru = unroll(r, 3)
|
||||
ru(batchone(Seq(squeeze.(xs))))[1] == squeeze.(ys)
|
||||
ru(unsqueeze(stack(squeeze.(xs))))[1] == squeeze.(ys)
|
||||
end
|
||||
|
@ -1,15 +1,13 @@
|
||||
using Flux, DataFlow, MacroTools, Base.Test
|
||||
using Flux: graph, Param, squeeze, unsqueeze, back!, update!, flatten
|
||||
using Flux: graph, Param, squeeze, unsqueeze, stack, back!, update!, flatten
|
||||
using DataFlow: Line, Frame
|
||||
|
||||
@testset "Flux" begin
|
||||
|
||||
include("batching.jl")
|
||||
include("backend/common.jl")
|
||||
|
||||
include("basic.jl")
|
||||
include("recurrent.jl")
|
||||
include("optimizer.jl")
|
||||
include("throttle.jl")
|
||||
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user