diff --git a/NEWS.md b/NEWS.md new file mode 100644 index 00000000..55f4c96d --- /dev/null +++ b/NEWS.md @@ -0,0 +1,19 @@ +# v0.8.0 + +* New [ConvTranspose layer](https://github.com/FluxML/Flux.jl/pull/311). +* Datasets are now [hash verified on download](https://github.com/FluxML/Flux.jl/pull/585) to avoid corruption. +* We now [zero the initial state for RNNs](https://github.com/FluxML/Flux.jl/pull/590/). +* [Normalisation can now work on arbitrary `dims`.](https://github.com/FluxML/Flux.jl/pull/592) +* Many docs and bugfixes thanks to @KristofferC and others. +* [NamedTuples now work like Tuples](https://github.com/FluxML/Flux.jl/pull/603) when doing `mapleaves`. +* New "performance tips" [section of the docs](https://github.com/FluxML/Flux.jl/pull/615). +* The training loop is [now more readable](https://github.com/FluxML/Flux.jl/pull/651) and better shows how to use the lower-level APIs. + +AD Changes: + +* `det`, `logdet` and `logabsdet` [now have adjoints](https://github.com/FluxML/Flux.jl/pull/596/files). +* Support for [PermuteDimsArray](https://github.com/FluxML/Flux.jl/pull/576). + +# v0.7.0 + +Despite the heroic efforts of scholars and archeologists, pre-0.7 history is lost to the sands of time. diff --git a/src/Flux.jl b/src/Flux.jl index f234950f..c806716d 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -7,7 +7,7 @@ using MacroTools, Juno, Requires, Reexport, Statistics, Random using MacroTools: @forward export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool, - DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, + DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, params, mapleaves, cpu, gpu, f32, f64 @reexport using NNlib diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index b81226f6..7aeaeade 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -144,34 +144,32 @@ BatchNorm(chs::Integer, λ = identity; function (BN::BatchNorm)(x) size(x, ndims(x)-1) == length(BN.β) || error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))") - γ, β = BN.γ, BN.β dims = length(size(x)) channels = size(x, dims-1) affine_shape = ones(Int, dims) affine_shape[end-1] = channels m = prod(size(x)[1:end-2]) * size(x)[end] - + γ = reshape(BN.γ, affine_shape...) + β = reshape(BN.β, affine_shape...) if !BN.active μ = reshape(BN.μ, affine_shape...) σ² = reshape(BN.σ², affine_shape...) + ϵ = BN.ϵ else T = eltype(x) - - ϵ = data(convert(T, BN.ϵ)) axes = [1:dims-2; dims] # axes to reduce along (all but channels axis) μ = mean(x, dims = axes) σ² = sum((x .- μ) .^ 2, dims = axes) ./ m - + ϵ = data(convert(T, BN.ϵ)) # update moving mean/std mtm = data(convert(T, BN.momentum)) BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :) - BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* reshape(data(σ²), :) .* m ./ (m - 1)) + BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), :) end let λ = BN.λ - temp = reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ BN.ϵ)) - # This is intentionally not fused because of an extreme slowdown doing so - λ.(temp .+ reshape(β, affine_shape...)) + x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ) + λ.(γ .* x̂ .+ β) end end @@ -188,3 +186,103 @@ function Base.show(io::IO, l::BatchNorm) (l.λ == identity) || print(io, ", λ = $(l.λ)") print(io, ")") end + + +""" + InstanceNorm(channels::Integer, σ = identity; + initβ = zeros, initγ = ones, + ϵ = 1e-8, momentum = .1) + +Instance Normalization layer. The `channels` input should be the size of the +channel dimension in your data (see below). 
+ +Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For +a batch of feature vectors this is just the data dimension, for `WHCN` images +it's the usual channel dimension.) + +`InstanceNorm` computes the mean and variance for each each `W×H×1×1` slice and +shifts them to have a new mean and variance (corresponding to the learnable, +per-channel `bias` and `scale` parameters). + +See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022). + +Example: +```julia +m = Chain( + Dense(28^2, 64), + InstanceNorm(64, relu), + Dense(64, 10), + InstanceNorm(10), + softmax) +``` +""" +expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...) + +mutable struct InstanceNorm{F,V,W,N} + λ::F # activation function + β::V # bias + γ::V # scale + μ::W # moving mean + σ²::W # moving std + ϵ::N + momentum::N + active::Bool +end + +InstanceNorm(chs::Integer, λ = identity; + initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) = + InstanceNorm(λ, param(initβ(chs)), param(initγ(chs)), + zeros(chs), ones(chs), ϵ, momentum, true) + +function (in::InstanceNorm)(x) + size(x, ndims(x)-1) == length(in.β) || + error("InstanceNorm expected $(length(in.β)) channels, got $(size(x, ndims(x)-1))") + ndims(x) > 2 || + error("InstanceNorm requires at least 3 dimensions. With 2 dimensions an array of zeros would be returned") + # these are repeated later on depending on the batch size + dims = length(size(x)) + c = size(x, dims-1) + bs = size(x, dims) + affine_shape = ones(Int, dims) + affine_shape[end-1] = c + affine_shape[end] = bs + m = prod(size(x)[1:end-2]) + γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape) + + if !in.active + μ = expand_inst(in.μ, affine_shape) + σ² = expand_inst(in.σ², affine_shape) + ϵ = in.ϵ + else + T = eltype(x) + + ϵ = data(convert(T, in.ϵ)) + axes = 1:dims-2 # axes to reduce along (all but channels and batch size axes) + μ = mean(x, dims = axes) + σ² = mean((x .- μ) .^ 2, dims = axes) + + # update moving mean/std + mtm = data(convert(T, in.momentum)) + in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), dims=2) + in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (c, bs))), dims = 2), dims=2) + end + + let λ = in.λ + x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ) + λ.(γ .* x̂ .+ β) + end +end + +children(in::InstanceNorm) = + (in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum, in.active) + +mapchildren(f, in::InstanceNorm) = # e.g. mapchildren(cu, in) + InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum, in.active) + +_testmode!(in::InstanceNorm, test) = (in.active = !test) + +function Base.show(io::IO, l::InstanceNorm) + print(io, "InstanceNorm($(join(size(l.β), ", "))") + (l.λ == identity) || print(io, ", λ = $(l.λ)") + print(io, ")") +end diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index 190d684d..40b8fd33 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -37,7 +37,7 @@ Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict()) function apply!(o::Momentum, x, Δ) η, ρ = o.eta, o.rho - v = get!(o.velocity, x, zero(x))::typeof(x) + v = get!(o.velocity, x, zero(x))::typeof(data(x)) @. v = ρ * v - η * Δ @. 
Δ = -v end @@ -57,7 +57,7 @@ Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict()) function apply!(o::Nesterov, x, Δ) η, ρ = o.eta, o.rho - v = get!(o.velocity, x, zero(x))::typeof(x) + v = get!(o.velocity, x, zero(x))::typeof(data(x)) d = @. ρ^2 * v - (1+ρ) * η * Δ @. v = ρ*v - η*Δ @. Δ = -d @@ -80,7 +80,7 @@ RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict()) function apply!(o::RMSProp, x, Δ) η, ρ = o.eta, o.rho - acc = get!(o.acc, x, zero(x))::typeof(x) + acc = get!(o.acc, x, zero(x))::typeof(data(x)) @. acc = ρ * acc + (1 - ρ) * Δ^2 @. Δ *= η / (√acc + ϵ) end @@ -147,7 +147,7 @@ ADAGrad(η = 0.1) = ADAGrad(η, IdDict()) function apply!(o::ADAGrad, x, Δ) η = o.eta - acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x) + acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(data(x)) @. acc += Δ^2 @. Δ *= η / (√acc + ϵ) end @@ -321,7 +321,7 @@ end WeightDecay() = WeightDecay(0) -function apply!(o::WeightDecay, x, Δ) +function apply!(o::WeightDecay, x, Δ) wd = o.wd - @. Δ += wd * x + @. Δ += wd * data(x) end diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 45fde760..ab8be578 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -1,16 +1,23 @@ using Juno -import Flux.Tracker: data, grad, back!, update! +import Flux.Tracker: Params, gradient, data, update! import Base.depwarn function update!(opt, x, x̄) - update!(x, apply!(opt, x, copy(data(x̄)))) + update!(x, -apply!(opt, x, data(x̄))) end -function _update_params!(opt, xs) +function update!(opt, xs::Params, gs) for x in xs - Δ = apply!(opt, x.data, x.grad) - x.data .-= Δ - Δ .= 0 + update!(opt, x, gs[x]) + end +end + +# Added as an internal API but everyone started using it. +function _update_params!(opt, xs) + depwarn("`_update_params!` is deprecated, use `update!` instead.", :stop) + for x in xs + update!(opt, x, Tracker.grad(x)) + x.tracker.grad = Tracker.zero_grad!(x.tracker.grad) end end @@ -19,16 +26,6 @@ call(f, xs...) = f(xs...) runall(f) = f runall(fs::AbstractVector) = () -> foreach(call, fs) -# The AD generates fairly large backtraces that are unhelpful if you interrupt -# while training; this just cleans that up. -macro interrupts(ex) - :(try $(esc(ex)) - catch e - e isa InterruptException || rethrow() - throw(e) - end) -end - struct StopException <: Exception end """ stop() @@ -67,13 +64,14 @@ The callback can call `Flux.stop()` to interrupt the training loop. Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays. """ function train!(loss, ps, data, opt; cb = () -> ()) + ps = Params(ps) cb = runall(cb) - opt = runall(opt) @progress for d in data try - l = loss(d...) - @interrupts back!(l) - _update_params!(opt, ps) + gs = gradient(ps) do + loss(d...) + end + update!(opt, ps, gs) if cb() == :stop depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop) break diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 2fbe6437..adceea61 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -62,6 +62,7 @@ macro grad(ex) end include("idset.jl") +include("params.jl") include("back.jl") include("numeric.jl") include("lib/real.jl") diff --git a/src/tracker/back.jl b/src/tracker/back.jl index ef65ecb6..2825a92c 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -1,3 +1,15 @@ +# The AD generates fairly large backtraces that are unhelpful if you interrupt +# while training; this just cleans that up. 
+macro interrupts(ex) + :(try $(esc(ex)) + catch e + e isa InterruptException || rethrow() + throw(e) + end) +end + +# In-place gradients + init_grad(x) = zero(x) zero_grad!(x) = zero(x) zero_grad!(x::AbstractArray) = (x .= 0) @@ -66,64 +78,34 @@ function back!(x, Δ; once = true) return end +function extract_grad!(x) + x̄ = copy(grad(x)) + x̄ = nobacksies("Use `gradient(...; nest = true)` for nested derivatives", x̄) + tracker(x).grad = zero_grad!(grad(x)) + return x̄ +end + function gradient_(f, xs...) xs = param.(data.(xs)) l = f(xs...) losscheck(l) - back!(l) - nobacksies("Use `gradient(...; nest = true)` for nested derivatives", - grad.(xs)) + @interrupts back!(l) + extract_grad!.(xs) +end + +function gradient_(f, xs::Params) + l = f() + losscheck(l) + @interrupts back!(l) + gs = Grads() + for x in xs + gs[tracker(x)] = extract_grad!(x) + end + return gs end # Out-of-place gradients -struct Params - order::Vector{Any} - params::IdSet{Any} - Params() = new([], IdSet()) -end - -@forward Params.order Base.iterate, Base.length - -function Base.push!(ps::Params, x) - if !(x in ps.params) - push!(ps.order, x) - push!(ps.params, x) - end - return ps -end - -Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps) - -Params(xs) = push!(Params(), xs...) - -function Base.show(io::IO, ps::Params) - print(io, "Params([") - join(io, ps.order, ", ") - print(io, "])") -end - -struct Grads - grads::IdDict{Any,Any} -end - -Base.show(io::IO, ps::Grads) = println(io, "Grads(...)") - -Grads() = Grads(IdDict()) - -@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate - -Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps)) - -Base.getindex(g::Grads, x::Tracked) = g.grads[x] - -function Base.getindex(g::Grads, x) - istracked(x) || error("Object not tracked: $x") - g[tracker(x)] -end - -accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ - function back_(g::Grads, c::Call, Δ) Δs = c.func(Δ) (Δs isa Tuple && length(Δs) >= length(c.args)) || @@ -182,8 +164,6 @@ end gradient(f, xs...; nest = false) = nest ? gradient_nested(f, xs...) : gradient_(f, xs...) -gradient(f, ps::Params) = gradient_nested(f, ps) - # Jacobians and Hessians import ..Flux diff --git a/src/tracker/lib/array.jl b/src/tracker/lib/array.jl index 52b92cf7..e09e99bc 100644 --- a/src/tracker/lib/array.jl +++ b/src/tracker/lib/array.jl @@ -71,6 +71,11 @@ function update!(x::TrackedArray, Δ) return x end +function update!(x::AbstractArray, Δ) + x .+= data(Δ) + return x +end + # Fallthrough methods for f in :[Base.size, Base.ndims, Base.collect].args diff --git a/src/tracker/params.jl b/src/tracker/params.jl new file mode 100644 index 00000000..7a1db1e9 --- /dev/null +++ b/src/tracker/params.jl @@ -0,0 +1,46 @@ +struct Params + order::Vector{Any} + params::IdSet{Any} + Params() = new([], IdSet()) +end + +@forward Params.order Base.iterate, Base.length + +function Base.push!(ps::Params, x) + if !(x in ps.params) + push!(ps.order, x) + push!(ps.params, x) + end + return ps +end + +Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps) + +Params(xs) = push!(Params(), xs...) 
+ +function Base.show(io::IO, ps::Params) + print(io, "Params([") + join(io, ps.order, ", ") + print(io, "])") +end + +struct Grads + grads::IdDict{Any,Any} +end + +Base.show(io::IO, ps::Grads) = println(io, "Grads(...)") + +Grads() = Grads(IdDict()) + +@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate + +Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps)) + +Base.getindex(g::Grads, x::Tracked) = g.grads[x] + +function Base.getindex(g::Grads, x) + istracked(x) || error("Object not tracked: $x") + g[tracker(x)] +end + +accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index 3ef9eb7a..d8629445 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -104,3 +104,99 @@ end @test (@allocated m(x)) < 100_000_000 end end + + +@testset "InstanceNorm" begin + # helper functions + expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...) + # begin tests + let m = InstanceNorm(2), sizes = (3, 2, 2), + x = param(reshape(collect(1:prod(sizes)), sizes)) + + @test m.β.data == [0, 0] # initβ(2) + @test m.γ.data == [1, 1] # initγ(2) + + @test m.active + + m(x) + + #julia> x + #[:, :, 1] = + # 1.0 4.0 + # 2.0 5.0 + # 3.0 6.0 + # + #[:, :, 2] = + # 7.0 10.0 + # 8.0 11.0 + # 9.0 12.0 + # + # μ will be + # (1. + 2. + 3.) / 3 = 2. + # (4. + 5. + 6.) / 3 = 5. + # + # (7. + 8. + 9.) / 3 = 8. + # (10. + 11. + 12.) / 3 = 11. + # + # ∴ update rule with momentum: + # (1. - .1) * 0 + .1 * (2. + 8.) / 2 = .5 + # (1. - .1) * 0 + .1 * (5. + 11.) / 2 = .8 + @test m.μ ≈ [0.5, 0.8] + # momentum * var * num_items / (num_items - 1) + (1 - momentum) * sigma_sq + # julia> reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1. + # 2-element Array{Float64,1}: + # 1. + # 1. + @test m.σ² ≈ reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1. + + testmode!(m) + @test !m.active + + x′ = m(x).data + @test isapprox(x′[1], (1 - 0.5) / sqrt(1. + 1f-5), atol = 1.0e-5) + end + # with activation function + let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2), + x = param(reshape(collect(1:prod(sizes)), sizes)) + + affine_shape = collect(sizes) + affine_shape[1] = 1 + + @test m.active + m(x) + + testmode!(m) + @test !m.active + + y = m(x).data + @test isapprox(y, data(sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ))), atol = 1.0e-7) + end + + let m = InstanceNorm(2), sizes = (2, 4, 1, 2, 3), + x = param(reshape(collect(1:prod(sizes)), sizes)) + y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3) + y = reshape(m(y), sizes...) 
+ @test m(x) == y + end + + # check that μ, σ², and the output are the correct size for higher rank tensors + let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6), + x = param(reshape(collect(1:prod(sizes)), sizes)) + y = m(x) + @test size(m.μ) == (sizes[end - 1], ) + @test size(m.σ²) == (sizes[end - 1], ) + @test size(y) == sizes + end + + # show that instance norm is equal to batch norm when channel and batch dims are squashed + let m_inorm = InstanceNorm(2), m_bnorm = BatchNorm(12), sizes = (5, 5, 3, 4, 2, 6), + x = param(reshape(collect(1:prod(sizes)), sizes)) + @test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes) + end + + let m = InstanceNorm(32), x = randn(Float32, 416, 416, 32, 1); + m(x) + @test (@allocated m(x)) < 100_000_000 + end + +end diff --git a/test/optimise.jl b/test/optimise.jl index abe1971f..0da94929 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -4,21 +4,15 @@ using Flux.Tracker using Test @testset "Optimise" begin w = randn(10, 10) - @testset for Opt in [ADAMW, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum] + @testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(), + NADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(), + Momentum()] w′ = param(randn(10, 10)) loss(x) = Flux.mse(w*x, w′*x) - opt = Opt(0.001) - if opt isa Descent || opt isa ADAGrad - opt = Opt(0.1) - end - if opt isa ADADelta - opt = Opt(0.9) - end for t = 1: 10^5 - l = loss(rand(10)) - back!(l) - delta = Optimise.apply!(opt, w′.data, w′.grad) - w′.data .-= delta + θ = Params([w′]) + θ̄ = gradient(() -> loss(rand(10)), θ) + Optimise.update!(opt, θ, θ̄) end @test Flux.mse(w, w′) < 0.01 end
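
A few usage sketches for the changes above follow; every model shape, variable name and data value in them is illustrative rather than taken from the diff. First, a minimal sketch of the new `InstanceNorm` layer. Note that the implementation errors on inputs with fewer than three dimensions, so it is shown here on `WHCN`-shaped data; statistics are taken over the spatial dimensions, separately for every (channel, sample) pair.

```julia
using Flux

# One learnable shift/scale pair per channel, with a relu activation.
inorm = InstanceNorm(3, relu)

# A WHCN batch: 32×32 "images", 3 channels, batch size 4 (illustrative sizes).
x = rand(Float32, 32, 32, 3, 4)

y = inorm(x)            # normalised over W and H for each (channel, sample) pair
size(y) == size(x)      # true

# In test mode the accumulated moving statistics are used instead of
# per-batch statistics.
Flux.testmode!(inorm)
inorm(x)
```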
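
The reworked training internals can also be driven by hand. The sketch below mirrors what `train!` now does for each batch: collect the parameters into a `Params`, take gradients with `Tracker.gradient`, then apply the optimiser with the three-argument `update!`. The linear model and sizes are arbitrary.

```julia
using Flux
using Flux.Tracker: Params, gradient, update!

# A throwaway linear model.
W = param(randn(Float32, 2, 5))
b = param(zeros(Float32, 2))
predict(x) = W * x .+ b
loss(x, y) = sum((predict(x) .- y) .^ 2)

x, y = rand(Float32, 5), rand(Float32, 2)
opt = Descent(0.1)

# What `train!` now does on each batch:
θ = Params([W, b])                    # parameters to differentiate with respect to
gs = gradient(() -> loss(x, y), θ)    # a `Grads` object, indexed by parameter
update!(opt, θ, gs)                   # apply the optimiser step in place

# Equivalently, hand everything to the training loop:
Flux.train!(loss, θ, [(x, y)], opt)
```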
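
Finally, since returning `:stop` from a callback is now deprecated, early stopping goes through `Flux.stop()`, as the updated `train!` docstring describes. A small sketch, with a made-up loss threshold and dataset:

```julia
using Flux

W = param(rand(Float32, 2, 5))
b = param(zeros(Float32, 2))
loss(x, y) = Flux.mse(W * x .+ b, y)

dataset = [(rand(Float32, 5), rand(Float32, 2)) for _ in 1:1000]
opt = ADAM()

# Interrupt training once the loss on the first sample is small enough.
evalcb() = Flux.Tracker.data(loss(dataset[1]...)) < 1f-3 && Flux.stop()

Flux.train!(loss, [W, b], dataset, opt, cb = evalcb)
```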