Merge branch 'master' into patch-3

Manjunath Bhat, 2019-03-07 23:08:40 +05:30, committed by GitHub
commit 47c1324476
11 changed files with 337 additions and 100 deletions

NEWS.md (new file, 19 lines)

@@ -0,0 +1,19 @@
# v0.8.0
* New [ConvTranspose layer](https://github.com/FluxML/Flux.jl/pull/311).
* Datasets are now [hash verified on download](https://github.com/FluxML/Flux.jl/pull/585) to avoid corruption.
* We now [zero the initial state for RNNs](https://github.com/FluxML/Flux.jl/pull/590/).
* [Normalisation can now work on arbitrary `dims`.](https://github.com/FluxML/Flux.jl/pull/592)
* Many docs and bugfixes thanks to @KristofferC and others.
* [NamedTuples now work like Tuples](https://github.com/FluxML/Flux.jl/pull/603) when doing `mapleaves`.
* New "performance tips" [section of the docs](https://github.com/FluxML/Flux.jl/pull/615).
* The training loop is [now more readable](https://github.com/FluxML/Flux.jl/pull/651) and better shows how to use the lower-level APIs.
AD Changes:
* `det`, `logdet` and `logabsdet` [now have adjoints](https://github.com/FluxML/Flux.jl/pull/596/files).
* Support for [PermuteDimsArray](https://github.com/FluxML/Flux.jl/pull/576).
# v0.7.0
Despite the heroic efforts of scholars and archeologists, pre-0.7 history is lost to the sands of time.
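The "more readable training loop" and lower-level API items above refer to `train!`, `params`, and `Flux.stop()`, which are touched later in this changeset. A rough sketch of how they fit together; the `Dense` model, toy data, optimiser choice, and tolerance below are made up for illustration and are not part of this commit:

```julia
using Flux

# Toy problem: fit a 2→1 linear layer. Only train!, params, Descent,
# throttle, and Flux.stop come from Flux; everything else is illustrative.
m = Dense(2, 1)
loss(x, y) = Flux.mse(m(x), y)
dataset = [(rand(2, 16), rand(1, 16)) for _ in 1:100]
opt = Descent(0.1)

# The callback may call Flux.stop() to break out of train! early.
evalcb = () -> loss(dataset[1]...) < 1e-3 && Flux.stop()
Flux.train!(loss, params(m), dataset, opt, cb = Flux.throttle(evalcb, 1))
```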

@@ -7,7 +7,7 @@ using MacroTools, Juno, Requires, Reexport, Statistics, Random
using MacroTools: @forward
export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
-  DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm,
+  DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm,
  params, mapleaves, cpu, gpu, f32, f64
@reexport using NNlib

@@ -144,34 +144,32 @@ BatchNorm(chs::Integer, λ = identity;
function (BN::BatchNorm)(x)
  size(x, ndims(x)-1) == length(BN.β) ||
    error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
-  γ, β = BN.γ, BN.β
  dims = length(size(x))
  channels = size(x, dims-1)
  affine_shape = ones(Int, dims)
  affine_shape[end-1] = channels
  m = prod(size(x)[1:end-2]) * size(x)[end]
+  γ = reshape(BN.γ, affine_shape...)
+  β = reshape(BN.β, affine_shape...)
  if !BN.active
    μ = reshape(BN.μ, affine_shape...)
    σ² = reshape(BN.σ², affine_shape...)
+    ϵ = BN.ϵ
  else
    T = eltype(x)
-    ϵ = data(convert(T, BN.ϵ))
    axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
    μ = mean(x, dims = axes)
    σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
+    ϵ = data(convert(T, BN.ϵ))
    # update moving mean/std
    mtm = data(convert(T, BN.momentum))
    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :)
-    BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* reshape(data(σ²), :) .* m ./ (m - 1))
+    BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), :)
  end

  let λ = BN.λ
-    temp = reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ BN.ϵ))
-    # This is intentionally not fused because of an extreme slowdown doing so
-    λ.(temp .+ reshape(β, affine_shape...))
+    x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)
+    λ.(γ .* x̂ .+ β)
  end
end
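In the rewritten branch above, the reshaped `γ`/`β` and the shared `ϵ` feed a single normalise-then-affine step, `λ.(γ .* x̂ .+ β)` with `x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)`. The plain-Julia sketch below recomputes that arithmetic for a tiny 3×2×2 array; the input values, `γ`, `β`, and `ϵ` are made up, and no Flux is required:

```julia
using Statistics

x = reshape(Float32.(1:12), 3, 2, 2)          # (width, channels, batch)
m = prod(size(x)[1:end-2]) * size(x)[end]     # elements per channel: 3 * 2 = 6
μ  = mean(x, dims = (1, 3))                   # per-channel mean, as in the diff
σ² = sum((x .- μ) .^ 2, dims = (1, 3)) ./ m   # biased per-channel variance
ϵ  = 1f-5

γ = reshape(Float32[1, 2], 1, 2, 1)           # per-channel scale (illustrative)
β = reshape(Float32[0, 1], 1, 2, 1)           # per-channel shift (illustrative)

x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)                # normalise
y = identity.(γ .* x̂ .+ β)                    # affine + activation (λ = identity)

# The running variance uses the unbiased correction m / (m - 1), matching
# BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), :).
```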
@@ -188,3 +186,103 @@ function Base.show(io::IO, l::BatchNorm)
  (l.λ == identity) || print(io, ", λ = $(l.λ)")
  print(io, ")")
end
"""
InstanceNorm(channels::Integer, σ = identity;
initβ = zeros, initγ = ones,
ϵ = 1e-8, momentum = .1)
Instance Normalization layer. The `channels` input should be the size of the
channel dimension in your data (see below).
Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
a batch of feature vectors this is just the data dimension, for `WHCN` images
it's the usual channel dimension.)
`InstanceNorm` computes the mean and variance for each each `W×H×1×1` slice and
shifts them to have a new mean and variance (corresponding to the learnable,
per-channel `bias` and `scale` parameters).
See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
Example:
```julia
m = Chain(
Dense(28^2, 64),
InstanceNorm(64, relu),
Dense(64, 10),
InstanceNorm(10),
softmax)
```
"""
expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
mutable struct InstanceNorm{F,V,W,N}
  λ::F  # activation function
  β::V  # bias
  γ::V  # scale
  μ::W  # moving mean
  σ²::W # moving std
  ϵ::N
  momentum::N
  active::Bool
end

InstanceNorm(chs::Integer, λ = identity;
             initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
  InstanceNorm(λ, param(initβ(chs)), param(initγ(chs)),
               zeros(chs), ones(chs), ϵ, momentum, true)

function (in::InstanceNorm)(x)
  size(x, ndims(x)-1) == length(in.β) ||
    error("InstanceNorm expected $(length(in.β)) channels, got $(size(x, ndims(x)-1))")
  ndims(x) > 2 ||
    error("InstanceNorm requires at least 3 dimensions. With 2 dimensions an array of zeros would be returned")
  # these are repeated later on depending on the batch size
  dims = length(size(x))
  c = size(x, dims-1)
  bs = size(x, dims)
  affine_shape = ones(Int, dims)
  affine_shape[end-1] = c
  affine_shape[end] = bs
  m = prod(size(x)[1:end-2])
  γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape)

  if !in.active
    μ = expand_inst(in.μ, affine_shape)
    σ² = expand_inst(in.σ², affine_shape)
    ϵ = in.ϵ
  else
    T = eltype(x)

    ϵ = data(convert(T, in.ϵ))
    axes = 1:dims-2 # axes to reduce along (all but channels and batch size axes)
    μ = mean(x, dims = axes)
    σ² = mean((x .- μ) .^ 2, dims = axes)

    # update moving mean/std
    mtm = data(convert(T, in.momentum))
    in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), dims=2)
    in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (c, bs))), dims = 2), dims=2)
  end

  let λ = in.λ
    x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)
    λ.(γ .* x̂ .+ β)
  end
end

children(in::InstanceNorm) =
  (in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum, in.active)

mapchildren(f, in::InstanceNorm) =  # e.g. mapchildren(cu, in)
  InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum, in.active)

_testmode!(in::InstanceNorm, test) = (in.active = !test)

function Base.show(io::IO, l::InstanceNorm)
  print(io, "InstanceNorm($(join(size(l.β), ", "))")
  (l.λ == identity) || print(io, ", λ = $(l.λ)")
  print(io, ")")
end
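The `expand_inst` helper above tiles per-channel parameters across the batch dimension before broadcasting. A standalone shape check, with made-up channel values and sizes:

```julia
# Same definition as above, evaluated outside Flux for a quick shape check.
expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)

γ  = [10.0, 20.0]        # one scale per channel (C = 2)
as = [1, 2, 3]           # affine_shape for a W×C×N input with C = 2, N = 3
e  = expand_inst(γ, as)

size(e)       # (1, 2, 3): ready to broadcast against a W×2×3 input
e[1, :, 2]    # returns [10.0, 20.0]; every sample sees the same per-channel values
```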

@@ -37,7 +37,7 @@ Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
function apply!(o::Momentum, x, Δ)
  η, ρ = o.eta, o.rho
-  v = get!(o.velocity, x, zero(x))::typeof(x)
+  v = get!(o.velocity, x, zero(x))::typeof(data(x))
  @. v = ρ * v - η * Δ
  @. Δ = -v
end

@@ -57,7 +57,7 @@ Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
function apply!(o::Nesterov, x, Δ)
  η, ρ = o.eta, o.rho
-  v = get!(o.velocity, x, zero(x))::typeof(x)
+  v = get!(o.velocity, x, zero(x))::typeof(data(x))
  d = @. ρ^2 * v - (1+ρ) * η * Δ
  @. v = ρ*v - η*Δ
  @. Δ = -d
@@ -80,7 +80,7 @@ RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
function apply!(o::RMSProp, x, Δ)
  η, ρ = o.eta, o.rho
-  acc = get!(o.acc, x, zero(x))::typeof(x)
+  acc = get!(o.acc, x, zero(x))::typeof(data(x))
  @. acc = ρ * acc + (1 - ρ) * Δ^2
  @. Δ *= η / (√acc + ϵ)
end

@@ -147,7 +147,7 @@ ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
function apply!(o::ADAGrad, x, Δ)
  η = o.eta
-  acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
+  acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(data(x))
  @. acc += Δ^2
  @. Δ *= η / (√acc + ϵ)
end

@@ -323,5 +323,5 @@ WeightDecay() = WeightDecay(0)
function apply!(o::WeightDecay, x, Δ)
  wd = o.wd
-  @. Δ += wd * x
+  @. Δ += wd * data(x)
end
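The `typeof(data(x))` assertions above let the optimiser state dictionaries hold plain arrays even when `x` is a tracked parameter. For reference, this is the calling convention the `apply!` methods follow: `apply!` turns a raw gradient into a step, and the caller (or `update!`) subtracts it. A minimal sketch using `Descent`, the simplest rule; the numbers are made up:

```julia
using Flux.Optimise: Descent, apply!

w = [1.0, 2.0, 3.0]
Δ = [0.1, 0.1, 0.1]        # a gradient for w
opt = Descent(0.5)         # learning rate 0.5

step = apply!(opt, w, Δ)   # rescales Δ (in place for most rules) and returns it
w .-= step                 # update!(opt, w, Δ) in this changeset wraps this pattern
```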

@@ -1,16 +1,23 @@
using Juno
-import Flux.Tracker: data, grad, back!, update!
+import Flux.Tracker: Params, gradient, data, update!
import Base.depwarn

function update!(opt, x, x̄)
-  update!(x, apply!(opt, x, copy(data(x̄))))
+  update!(x, -apply!(opt, x, data(x̄)))
end

-function _update_params!(opt, xs)
+function update!(opt, xs::Params, gs)
  for x in xs
-    Δ = apply!(opt, x.data, x.grad)
-    x.data .-= Δ
-    Δ .= 0
+    update!(opt, x, gs[x])
+  end
+end
+
+# Added as an internal API but everyone started using it.
+function _update_params!(opt, xs)
+  depwarn("`_update_params!` is deprecated, use `update!` instead.", :stop)
+  for x in xs
+    update!(opt, x, Tracker.grad(x))
+    x.tracker.grad = Tracker.zero_grad!(x.tracker.grad)
  end
end
@@ -19,16 +26,6 @@ call(f, xs...) = f(xs...)
runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs)

-# The AD generates fairly large backtraces that are unhelpful if you interrupt
-# while training; this just cleans that up.
-macro interrupts(ex)
-  :(try $(esc(ex))
-    catch e
-      e isa InterruptException || rethrow()
-      throw(e)
-    end)
-end
-
struct StopException <: Exception end

"""
    stop()
@@ -67,13 +64,14 @@ The callback can call `Flux.stop()` to interrupt the training loop.
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
"""
function train!(loss, ps, data, opt; cb = () -> ())
+  ps = Params(ps)
  cb = runall(cb)
-  opt = runall(opt)
  @progress for d in data
    try
-      l = loss(d...)
-      @interrupts back!(l)
-      _update_params!(opt, ps)
+      gs = gradient(ps) do
+        loss(d...)
+      end
+      update!(opt, ps, gs)
      if cb() == :stop
        depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
        break
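After this change `train!` is a thin wrapper over `gradient` and `update!`, so the same step can be written by hand. A hedged sketch of one such step; the model, data, and optimiser choice are illustrative, not from this commit:

```julia
using Flux
using Flux.Tracker: gradient

m = Dense(4, 2)
loss(x, y) = Flux.mse(m(x), y)
x, y = rand(4, 8), rand(2, 8)

ps  = Flux.params(m)                   # Params collection of m's weights
opt = ADAM()
gs  = gradient(() -> loss(x, y), ps)   # Grads keyed by each parameter
Flux.Optimise.update!(opt, ps, gs)     # one optimisation step, as in train! above
```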

@@ -62,6 +62,7 @@ macro grad(ex)
end

include("idset.jl")
+include("params.jl")
include("back.jl")
include("numeric.jl")
include("lib/real.jl")

@@ -1,3 +1,15 @@
+# The AD generates fairly large backtraces that are unhelpful if you interrupt
+# while training; this just cleans that up.
+macro interrupts(ex)
+  :(try $(esc(ex))
+    catch e
+      e isa InterruptException || rethrow()
+      throw(e)
+    end)
+end
+
+# In-place gradients
init_grad(x) = zero(x)
zero_grad!(x) = zero(x)
zero_grad!(x::AbstractArray) = (x .= 0)
@@ -66,64 +78,34 @@ function back!(x, Δ; once = true)
  return
end

+function extract_grad!(x)
+  x̄ = copy(grad(x))
+  x̄ = nobacksies("Use `gradient(...; nest = true)` for nested derivatives", x̄)
+  tracker(x).grad = zero_grad!(grad(x))
+  return x̄
+end
+
function gradient_(f, xs...)
  xs = param.(data.(xs))
  l = f(xs...)
  losscheck(l)
-  back!(l)
-  nobacksies("Use `gradient(...; nest = true)` for nested derivatives",
-             grad.(xs))
+  @interrupts back!(l)
+  extract_grad!.(xs)
end

+function gradient_(f, xs::Params)
+  l = f()
+  losscheck(l)
+  @interrupts back!(l)
+  gs = Grads()
+  for x in xs
+    gs[tracker(x)] = extract_grad!(x)
+  end
+  return gs
+end
+
# Out-of-place gradients

-struct Params
-  order::Vector{Any}
-  params::IdSet{Any}
-  Params() = new([], IdSet())
-end
-
-@forward Params.order Base.iterate, Base.length
-
-function Base.push!(ps::Params, x)
-  if !(x in ps.params)
-    push!(ps.order, x)
-    push!(ps.params, x)
-  end
-  return ps
-end
-
-Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
-
-Params(xs) = push!(Params(), xs...)
-
-function Base.show(io::IO, ps::Params)
-  print(io, "Params([")
-  join(io, ps.order, ", ")
-  print(io, "])")
-end
-
-struct Grads
-  grads::IdDict{Any,Any}
-end
-
-Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
-
-Grads() = Grads(IdDict())
-
-@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate
-
-Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
-
-Base.getindex(g::Grads, x::Tracked) = g.grads[x]
-function Base.getindex(g::Grads, x)
-  istracked(x) || error("Object not tracked: $x")
-  g[tracker(x)]
-end
-
-accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
-
function back_(g::Grads, c::Call, Δ)
  Δs = c.func(Δ)
  (Δs isa Tuple && length(Δs) >= length(c.args)) ||
@@ -182,8 +164,6 @@ end
gradient(f, xs...; nest = false) =
  nest ? gradient_nested(f, xs...) : gradient_(f, xs...)

-gradient(f, ps::Params) = gradient_nested(f, ps)
-
# Jacobians and Hessians

import ..Flux

@@ -71,6 +71,11 @@ function update!(x::TrackedArray, Δ)
  return x
end

+function update!(x::AbstractArray, Δ)
+  x .+= data(Δ)
+  return x
+end
+
# Fallthrough methods

for f in :[Base.size, Base.ndims, Base.collect].args

src/tracker/params.jl (new file, 46 lines)

@@ -0,0 +1,46 @@
struct Params
  order::Vector{Any}
  params::IdSet{Any}
  Params() = new([], IdSet())
end

@forward Params.order Base.iterate, Base.length

function Base.push!(ps::Params, x)
  if !(x in ps.params)
    push!(ps.order, x)
    push!(ps.params, x)
  end
  return ps
end

Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)

Params(xs) = push!(Params(), xs...)

function Base.show(io::IO, ps::Params)
  print(io, "Params([")
  join(io, ps.order, ", ")
  print(io, "])")
end

struct Grads
  grads::IdDict{Any,Any}
end

Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")

Grads() = Grads(IdDict())

@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate

Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))

Base.getindex(g::Grads, x::Tracked) = g.grads[x]

function Base.getindex(g::Grads, x)
  istracked(x) || error("Object not tracked: $x")
  g[tracker(x)]
end

accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
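Taken together, `Params` is an insertion-ordered, identity-keyed collection of parameters, and `Grads` maps each parameter's tracker to its accumulated gradient. A small usage sketch against this API; the arrays below are made up:

```julia
using Flux.Tracker: param, Params, gradient

W, b = param(rand(2, 3)), param(zeros(2))
ps = Params([W, b])               # de-duplicated, iteration order preserved

x  = rand(3)
gs = gradient(() -> sum(W * x .+ b), ps)

gs[W]                             # gradient w.r.t. W (same shape as W)
gs[b]                             # gradient w.r.t. b
```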

@@ -104,3 +104,99 @@ end
@test (@allocated m(x)) < 100_000_000
end
end
@testset "InstanceNorm" begin
# helper functions
expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
# begin tests
let m = InstanceNorm(2), sizes = (3, 2, 2),
x = param(reshape(collect(1:prod(sizes)), sizes))
@test m.β.data == [0, 0] # initβ(2)
@test m.γ.data == [1, 1] # initγ(2)
@test m.active
m(x)
#julia> x
#[:, :, 1] =
# 1.0 4.0
# 2.0 5.0
# 3.0 6.0
#
#[:, :, 2] =
# 7.0 10.0
# 8.0 11.0
# 9.0 12.0
#
# μ will be
# (1. + 2. + 3.) / 3 = 2.
# (4. + 5. + 6.) / 3 = 5.
#
# (7. + 8. + 9.) / 3 = 8.
# (10. + 11. + 12.) / 3 = 11.
#
# ∴ update rule with momentum:
# (1. - .1) * 0 + .1 * (2. + 8.) / 2 = .5
# (1. - .1) * 0 + .1 * (5. + 11.) / 2 = .8
@test m.μ ≈ [0.5, 0.8]
# momentum * var * num_items / (num_items - 1) + (1 - momentum) * sigma_sq
# julia> reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
# 2-element Array{Float64,1}:
# 1.
# 1.
@test m.σ² ≈ reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
testmode!(m)
@test !m.active
x = m(x).data
@test isapprox(x[1], (1 - 0.5) / sqrt(1. + 1f-5), atol = 1.0e-5)
end
# with activation function
let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2),
x = param(reshape(collect(1:prod(sizes)), sizes))
affine_shape = collect(sizes)
affine_shape[1] = 1
@test m.active
m(x)
testmode!(m)
@test !m.active
y = m(x).data
@test isapprox(y, data(sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ))), atol = 1.0e-7)
end
let m = InstanceNorm(2), sizes = (2, 4, 1, 2, 3),
x = param(reshape(collect(1:prod(sizes)), sizes))
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
y = reshape(m(y), sizes...)
@test m(x) == y
end
# check that μ, σ², and the output are the correct size for higher rank tensors
let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6),
x = param(reshape(collect(1:prod(sizes)), sizes))
y = m(x)
@test size(m.μ) == (sizes[end - 1], )
@test size(m.σ²) == (sizes[end - 1], )
@test size(y) == sizes
end
# show that instance norm is equal to batch norm when channel and batch dims are squashed
let m_inorm = InstanceNorm(2), m_bnorm = BatchNorm(12), sizes = (5, 5, 3, 4, 2, 6),
x = param(reshape(collect(1:prod(sizes)), sizes))
@test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
end
let m = InstanceNorm(32), x = randn(Float32, 416, 416, 32, 1);
m(x)
@test (@allocated m(x)) < 100_000_000
end
end

@@ -4,21 +4,15 @@ using Flux.Tracker
using Test

@testset "Optimise" begin
  w = randn(10, 10)
-  @testset for Opt in [ADAMW, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
+  @testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
+                       NADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
+                       Momentum()]
    w′ = param(randn(10, 10))
    loss(x) = Flux.mse(w*x, w′*x)
-    opt = Opt(0.001)
-    if opt isa Descent || opt isa ADAGrad
-      opt = Opt(0.1)
-    end
-    if opt isa ADADelta
-      opt = Opt(0.9)
-    end
    for t = 1: 10^5
-      l = loss(rand(10))
-      back!(l)
-      delta = Optimise.apply!(opt, w′.data, w′.grad)
-      w′.data .-= delta
+      θ = Params([w′])
+      θ̄ = gradient(() -> loss(rand(10)), θ)
+      Optimise.update!(opt, θ, θ̄)
    end
    @test Flux.mse(w, w′) < 0.01
  end