From 6044421c5cb2660c7baaff1d9500fa0f871e045b Mon Sep 17 00:00:00 2001 From: Sklan Date: Wed, 20 Feb 2019 13:47:31 +0530 Subject: [PATCH 01/10] Update normalise.jl --- src/layers/normalise.jl | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 1783d3ef..b9f7a86c 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -113,34 +113,32 @@ BatchNorm(chs::Integer, λ = identity; function (BN::BatchNorm)(x) size(x, ndims(x)-1) == length(BN.β) || error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))") - γ, β = BN.γ, BN.β dims = length(size(x)) channels = size(x, dims-1) affine_shape = ones(Int, dims) affine_shape[end-1] = channels m = prod(size(x)[1:end-2]) * size(x)[end] - + γ = reshape(BN.γ, affine_shape...) + β = reshape(BN.β, affine_shape...) if !BN.active μ = reshape(BN.μ, affine_shape...) σ² = reshape(BN.σ², affine_shape...) + ϵ = BN.ϵ else T = eltype(x) - - ϵ = data(convert(T, BN.ϵ)) axes = [1:dims-2; dims] # axes to reduce along (all but channels axis) μ = mean(x, dims = axes) σ² = sum((x .- μ) .^ 2, dims = axes) ./ m - + ϵ = data(convert(T, BN.ϵ)) # update moving mean/std mtm = data(convert(T, BN.momentum)) BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :) - BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* reshape(data(σ²), :) .* m ./ (m - 1)) + BN.σ² = (1 - mtm) .* BN.σ² .+ mtm .* reshape(data(σ²), :) .* m ./ (m - 1) end let λ = BN.λ - temp = reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ BN.ϵ)) - # This is intentionally not fused because of an extreme slowdown doing so - λ.(temp .+ reshape(β, affine_shape...)) + x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ) + λ.(γ .* x̂ .+ β) end end From 7463f0959176912d7a41bc57e0f20e7b14bf4902 Mon Sep 17 00:00:00 2001 From: Sklan Date: Thu, 21 Feb 2019 23:56:19 +0530 Subject: [PATCH 02/10] Update normalise.jl --- src/layers/normalise.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index b9f7a86c..e48d26fb 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -133,7 +133,7 @@ function (BN::BatchNorm)(x) # update moving mean/std mtm = data(convert(T, BN.momentum)) BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :) - BN.σ² = (1 - mtm) .* BN.σ² .+ mtm .* reshape(data(σ²), :) .* m ./ (m - 1) + BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), :) end let λ = BN.λ From 8b4bc7cc5245c6c8cd2a30ad872b9b1800c2e7e9 Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Thu, 28 Feb 2019 13:44:54 +0000 Subject: [PATCH 03/10] organise params --- src/tracker/Tracker.jl | 1 + src/tracker/back.jl | 49 ++---------------------------------------- src/tracker/params.jl | 46 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 47 deletions(-) create mode 100644 src/tracker/params.jl diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 2fbe6437..adceea61 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -62,6 +62,7 @@ macro grad(ex) end include("idset.jl") +include("params.jl") include("back.jl") include("numeric.jl") include("lib/real.jl") diff --git a/src/tracker/back.jl b/src/tracker/back.jl index ef65ecb6..0dda0082 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -1,3 +1,5 @@ +# In-place gradients + init_grad(x) = zero(x) zero_grad!(x) = zero(x) zero_grad!(x::AbstractArray) = (x .= 0) @@ -77,53 +79,6 @@ end # Out-of-place gradients -struct Params - order::Vector{Any} - params::IdSet{Any} - 
Params() = new([], IdSet()) -end - -@forward Params.order Base.iterate, Base.length - -function Base.push!(ps::Params, x) - if !(x in ps.params) - push!(ps.order, x) - push!(ps.params, x) - end - return ps -end - -Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps) - -Params(xs) = push!(Params(), xs...) - -function Base.show(io::IO, ps::Params) - print(io, "Params([") - join(io, ps.order, ", ") - print(io, "])") -end - -struct Grads - grads::IdDict{Any,Any} -end - -Base.show(io::IO, ps::Grads) = println(io, "Grads(...)") - -Grads() = Grads(IdDict()) - -@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate - -Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps)) - -Base.getindex(g::Grads, x::Tracked) = g.grads[x] - -function Base.getindex(g::Grads, x) - istracked(x) || error("Object not tracked: $x") - g[tracker(x)] -end - -accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ - function back_(g::Grads, c::Call, Δ) Δs = c.func(Δ) (Δs isa Tuple && length(Δs) >= length(c.args)) || diff --git a/src/tracker/params.jl b/src/tracker/params.jl new file mode 100644 index 00000000..7a1db1e9 --- /dev/null +++ b/src/tracker/params.jl @@ -0,0 +1,46 @@ +struct Params + order::Vector{Any} + params::IdSet{Any} + Params() = new([], IdSet()) +end + +@forward Params.order Base.iterate, Base.length + +function Base.push!(ps::Params, x) + if !(x in ps.params) + push!(ps.order, x) + push!(ps.params, x) + end + return ps +end + +Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps) + +Params(xs) = push!(Params(), xs...) + +function Base.show(io::IO, ps::Params) + print(io, "Params([") + join(io, ps.order, ", ") + print(io, "])") +end + +struct Grads + grads::IdDict{Any,Any} +end + +Base.show(io::IO, ps::Grads) = println(io, "Grads(...)") + +Grads() = Grads(IdDict()) + +@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate + +Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps)) + +Base.getindex(g::Grads, x::Tracked) = g.grads[x] + +function Base.getindex(g::Grads, x) + istracked(x) || error("Object not tracked: $x") + g[tracker(x)] +end + +accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ From cd091ad005c2a1e6563f3689ed6637bd7c4ac152 Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Thu, 28 Feb 2019 14:08:01 +0000 Subject: [PATCH 04/10] in place implicit gradients --- src/tracker/back.jl | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/src/tracker/back.jl b/src/tracker/back.jl index 0dda0082..03fe14bb 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -68,13 +68,30 @@ function back!(x, Δ; once = true) return end +function extract_grad!(x) + x̄ = copy(grad(x)) + x̄ = nobacksies("Use `gradient(...; nest = true)` for nested derivatives", x̄) + tracker(x).grad = zero_grad!(grad(x)) + return x̄ +end + function gradient_(f, xs...) xs = param.(data.(xs)) l = f(xs...) losscheck(l) back!(l) - nobacksies("Use `gradient(...; nest = true)` for nested derivatives", - grad.(xs)) + extract_grad!.(xs) +end + +function gradient_(f, xs::Params) + l = f() + losscheck(l) + back!(l) + gs = Grads() + for x in xs + gs[tracker(x)] = extract_grad!(x) + end + return gs end # Out-of-place gradients @@ -137,8 +154,6 @@ end gradient(f, xs...; nest = false) = nest ? gradient_nested(f, xs...) : gradient_(f, xs...) 
-gradient(f, ps::Params) = gradient_nested(f, ps) - # Jacobians and Hessians import ..Flux From 4cf43c0c41ff44dcbc6fea1d94b37c006a65c82d Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Thu, 28 Feb 2019 14:58:42 +0000 Subject: [PATCH 05/10] simpler/nicer training loop --- src/optimise/optimisers.jl | 12 ++++++------ src/optimise/train.jl | 38 ++++++++++++++++++-------------------- src/tracker/back.jl | 14 ++++++++++++-- src/tracker/lib/array.jl | 5 +++++ test/optimise.jl | 18 ++++++------------ 5 files changed, 47 insertions(+), 40 deletions(-) diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index 190d684d..40b8fd33 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -37,7 +37,7 @@ Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict()) function apply!(o::Momentum, x, Δ) η, ρ = o.eta, o.rho - v = get!(o.velocity, x, zero(x))::typeof(x) + v = get!(o.velocity, x, zero(x))::typeof(data(x)) @. v = ρ * v - η * Δ @. Δ = -v end @@ -57,7 +57,7 @@ Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict()) function apply!(o::Nesterov, x, Δ) η, ρ = o.eta, o.rho - v = get!(o.velocity, x, zero(x))::typeof(x) + v = get!(o.velocity, x, zero(x))::typeof(data(x)) d = @. ρ^2 * v - (1+ρ) * η * Δ @. v = ρ*v - η*Δ @. Δ = -d @@ -80,7 +80,7 @@ RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict()) function apply!(o::RMSProp, x, Δ) η, ρ = o.eta, o.rho - acc = get!(o.acc, x, zero(x))::typeof(x) + acc = get!(o.acc, x, zero(x))::typeof(data(x)) @. acc = ρ * acc + (1 - ρ) * Δ^2 @. Δ *= η / (√acc + ϵ) end @@ -147,7 +147,7 @@ ADAGrad(η = 0.1) = ADAGrad(η, IdDict()) function apply!(o::ADAGrad, x, Δ) η = o.eta - acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x) + acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(data(x)) @. acc += Δ^2 @. Δ *= η / (√acc + ϵ) end @@ -321,7 +321,7 @@ end WeightDecay() = WeightDecay(0) -function apply!(o::WeightDecay, x, Δ) +function apply!(o::WeightDecay, x, Δ) wd = o.wd - @. Δ += wd * x + @. Δ += wd * data(x) end diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 45fde760..ab8be578 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -1,16 +1,23 @@ using Juno -import Flux.Tracker: data, grad, back!, update! +import Flux.Tracker: Params, gradient, data, update! import Base.depwarn function update!(opt, x, x̄) - update!(x, apply!(opt, x, copy(data(x̄)))) + update!(x, -apply!(opt, x, data(x̄))) end -function _update_params!(opt, xs) +function update!(opt, xs::Params, gs) for x in xs - Δ = apply!(opt, x.data, x.grad) - x.data .-= Δ - Δ .= 0 + update!(opt, x, gs[x]) + end +end + +# Added as an internal API but everyone started using it. +function _update_params!(opt, xs) + depwarn("`_update_params!` is deprecated, use `update!` instead.", :stop) + for x in xs + update!(opt, x, Tracker.grad(x)) + x.tracker.grad = Tracker.zero_grad!(x.tracker.grad) end end @@ -19,16 +26,6 @@ call(f, xs...) = f(xs...) runall(f) = f runall(fs::AbstractVector) = () -> foreach(call, fs) -# The AD generates fairly large backtraces that are unhelpful if you interrupt -# while training; this just cleans that up. -macro interrupts(ex) - :(try $(esc(ex)) - catch e - e isa InterruptException || rethrow() - throw(e) - end) -end - struct StopException <: Exception end """ stop() @@ -67,13 +64,14 @@ The callback can call `Flux.stop()` to interrupt the training loop. Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays. 
""" function train!(loss, ps, data, opt; cb = () -> ()) + ps = Params(ps) cb = runall(cb) - opt = runall(opt) @progress for d in data try - l = loss(d...) - @interrupts back!(l) - _update_params!(opt, ps) + gs = gradient(ps) do + loss(d...) + end + update!(opt, ps, gs) if cb() == :stop depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop) break diff --git a/src/tracker/back.jl b/src/tracker/back.jl index 03fe14bb..2825a92c 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -1,3 +1,13 @@ +# The AD generates fairly large backtraces that are unhelpful if you interrupt +# while training; this just cleans that up. +macro interrupts(ex) + :(try $(esc(ex)) + catch e + e isa InterruptException || rethrow() + throw(e) + end) +end + # In-place gradients init_grad(x) = zero(x) @@ -79,14 +89,14 @@ function gradient_(f, xs...) xs = param.(data.(xs)) l = f(xs...) losscheck(l) - back!(l) + @interrupts back!(l) extract_grad!.(xs) end function gradient_(f, xs::Params) l = f() losscheck(l) - back!(l) + @interrupts back!(l) gs = Grads() for x in xs gs[tracker(x)] = extract_grad!(x) diff --git a/src/tracker/lib/array.jl b/src/tracker/lib/array.jl index 52b92cf7..e09e99bc 100644 --- a/src/tracker/lib/array.jl +++ b/src/tracker/lib/array.jl @@ -71,6 +71,11 @@ function update!(x::TrackedArray, Δ) return x end +function update!(x::AbstractArray, Δ) + x .+= data(Δ) + return x +end + # Fallthrough methods for f in :[Base.size, Base.ndims, Base.collect].args diff --git a/test/optimise.jl b/test/optimise.jl index abe1971f..0da94929 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -4,21 +4,15 @@ using Flux.Tracker using Test @testset "Optimise" begin w = randn(10, 10) - @testset for Opt in [ADAMW, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum] + @testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(), + NADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(), + Momentum()] w′ = param(randn(10, 10)) loss(x) = Flux.mse(w*x, w′*x) - opt = Opt(0.001) - if opt isa Descent || opt isa ADAGrad - opt = Opt(0.1) - end - if opt isa ADADelta - opt = Opt(0.9) - end for t = 1: 10^5 - l = loss(rand(10)) - back!(l) - delta = Optimise.apply!(opt, w′.data, w′.grad) - w′.data .-= delta + θ = Params([w′]) + θ̄ = gradient(() -> loss(rand(10)), θ) + Optimise.update!(opt, θ, θ̄) end @test Flux.mse(w, w′) < 0.01 end From 129a708b6f0c36b794729d99535ac56e5a63f4fb Mon Sep 17 00:00:00 2001 From: David Pollack Date: Wed, 20 Feb 2019 14:01:05 +0100 Subject: [PATCH 06/10] instance normalization --- src/Flux.jl | 2 +- src/layers/normalise.jl | 98 ++++++++++++++++++++++++++++++++++++ test/layers/normalisation.jl | 81 +++++++++++++++++++++++++++++ 3 files changed, 180 insertions(+), 1 deletion(-) diff --git a/src/Flux.jl b/src/Flux.jl index 32982131..a8bd4f0b 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -7,7 +7,7 @@ using MacroTools, Juno, Requires, Reexport, Statistics, Random using MacroTools: @forward export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool, - DepthwiseConv, Dropout, LayerNorm, BatchNorm, + DepthwiseConv, Dropout, LayerNorm, BatchNorm, InstanceNorm, params, mapleaves, cpu, gpu, f32, f64 @reexport using NNlib diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index e48d26fb..eaa994b2 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -155,3 +155,101 @@ function Base.show(io::IO, l::BatchNorm) (l.λ == identity) || print(io, ", λ = $(l.λ)") print(io, ")") end + + +""" + 
InstanceNorm(channels::Integer, σ = identity; + initβ = zeros, initγ = ones, + ϵ = 1e-8, momentum = .1) + +Instance Normalization layer. The `channels` input should be the size of the +channel dimension in your data (see below). + +Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For +a batch of feature vectors this is just the data dimension, for `WHCN` images +it's the usual channel dimension.) + +`InstanceNorm` computes the mean and variance for each each `W×H×1×1` slice and +shifts them to have a new mean and variance (corresponding to the learnable, +per-channel `bias` and `scale` parameters). + +See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022). + +Example: +```julia +m = Chain( + Dense(28^2, 64), + InstanceNorm(64, relu), + Dense(64, 10), + InstanceNorm(10), + softmax) +``` +""" +mutable struct InstanceNorm{F,V,W,N} + λ::F # activation function + β::V # bias + γ::V # scale + μ::W # moving mean + σ²::W # moving std + ϵ::N + momentum::N + active::Bool +end + +InstanceNorm(chs::Integer, λ = identity; + initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) = + InstanceNorm(λ, param(initβ(chs)), param(initγ(chs)), + zeros(chs), ones(chs), ϵ, momentum, true) + +function (IN::InstanceNorm)(x) + size(x, ndims(x)-1) == length(IN.β) || + error("InstanceNorm expected $(length(IN.β)) channels, got $(size(x, ndims(x)-1))") + ndims(x) > 2 || + error("InstanceNorm requires at least 3 dimensions. With 2 dimensions an array of zeros would be returned") + # these are repeated later on depending on the batch size + γ, β = IN.γ, IN.β + dims = length(size(x)) + c = size(x, dims-1) + bs = size(x, dims) + affine_shape = ones(Int, dims) + affine_shape[end-1] = c + affine_shape[end] = bs + m = prod(size(x)[1:end-2]) + + if !IN.active + μ = reshape(repeat(IN.μ, outer=[bs]), affine_shape...) + σ² = reshape(repeat(IN.σ², outer=[bs]), affine_shape...) + else + T = eltype(x) + + ϵ = data(convert(T, IN.ϵ)) + axes = 1:dims-2 # axes to reduce along (all but channels and batch size axes) + μ = mean(x, dims = axes) + σ² = mean((x .- μ) .^ 2, dims = axes) + + # update moving mean/std + mtm = data(convert(T, IN.momentum)) + IN.μ = reshape(mean((1 - mtm) .* repeat(IN.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), :) + IN.σ² = reshape(mean(((1 - mtm) .* repeat(IN.σ², outer=[1, bs]) .+ mtm .* reshape(data(σ²), (c, bs)) .* (m / (m - 1))), dims = 2), :) + end + + let λ = IN.λ + temp = reshape(repeat(γ, outer=[bs]), affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ IN.ϵ)) + # This is intentionally not fused because of an extreme slowdown doing so + λ.(temp .+ reshape(repeat(β, outer=[bs]), affine_shape...)) + end +end + +children(IN::InstanceNorm) = + (IN.λ, IN.β, IN.γ, IN.μ, IN.σ², IN.ϵ, IN.momentum, IN.active) + +mapchildren(f, IN::InstanceNorm) = # e.g. 
mapchildren(cu, IN) + InstanceNorm(IN.λ, f(IN.β), f(IN.γ), f(IN.μ), f(IN.σ²), IN.ϵ, IN.momentum, IN.active) + +_testmode!(IN::InstanceNorm, test) = (IN.active = !test) + +function Base.show(io::IO, l::InstanceNorm) + print(io, "InstanceNorm($(join(size(l.β), ", "))") + (l.λ == identity) || print(io, ", λ = $(l.λ)") + print(io, ")") +end diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index 3ef9eb7a..a249a4f4 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -104,3 +104,84 @@ end @test (@allocated m(x)) < 100_000_000 end end + + +@testset "InstanceNorm" begin + # helper functions + expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...) + # begin tests + let m = InstanceNorm(2), sizes = (3, 2, 2), + x = param(reshape(collect(1:prod(sizes)), sizes)) + + @test m.β.data == [0, 0] # initβ(2) + @test m.γ.data == [1, 1] # initγ(2) + + @test m.active + + m(x) + + #julia> x + #[:, :, 1] = + # 1.0 4.0 + # 2.0 5.0 + # 3.0 6.0 + # + #[:, :, 2] = + # 7.0 10.0 + # 8.0 11.0 + # 9.0 12.0 + # + # μ will be + # (1. + 2. + 3.) / 3 = 2. + # (4. + 5. + 6.) / 3 = 5. + # + # (7. + 8. + 9.) / 3 = 8. + # (10. + 11. + 12.) / 3 = 11. + # + # ∴ update rule with momentum: + # (1. - .1) * 0 + .1 * (2. + 8.) / 2 = .5 + # (1. - .1) * 0 + .1 * (5. + 11.) / 2 = .8 + @test m.μ ≈ [0.5, 0.8] + # momentum * var * num_items / (num_items - 1) + (1 - momentum) * sigma_sq + # julia> reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1. + # 2-element Array{Float64,1}: + # 1. + # 1. + @test m.σ² ≈ reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1. + + testmode!(m) + @test !m.active + + x′ = m(x).data + @test isapprox(x′[1], (1 - 0.5) / sqrt(1. + 1f-5), atol = 1.0e-5) + end + # with activation function + let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2), + x = param(reshape(collect(1:prod(sizes)), sizes)) + + affine_shape = collect(sizes) + affine_shape[1] = 1 + + @test m.active + m(x) + + testmode!(m) + @test !m.active + + y = m(x).data + @test isapprox(y, data(sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ))), atol = 1.0e-7) + end + + let m = InstanceNorm(2), sizes = (2, 4, 1, 2, 3), + x = param(reshape(collect(1:prod(sizes)), sizes)) + y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3) + y = reshape(m(y), sizes...) + @test m(x) == y + end + + let m = BatchNorm(32), x = randn(Float32, 416, 416, 32, 1); + m(x) + @test (@allocated m(x)) < 100_000_000 + end + +end From c41f8910052ab4ee85374c339bd8651fd84b4597 Mon Sep 17 00:00:00 2001 From: David Pollack Date: Wed, 20 Feb 2019 14:51:55 +0100 Subject: [PATCH 07/10] changes based on the improved batchnorm in PR#633 --- src/layers/normalise.jl | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index eaa994b2..168f3363 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -185,6 +185,7 @@ m = Chain( softmax) ``` """ +expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...) mutable struct InstanceNorm{F,V,W,N} λ::F # activation function β::V # bias @@ -207,7 +208,6 @@ function (IN::InstanceNorm)(x) ndims(x) > 2 || error("InstanceNorm requires at least 3 dimensions. 
With 2 dimensions an array of zeros would be returned") # these are repeated later on depending on the batch size - γ, β = IN.γ, IN.β dims = length(size(x)) c = size(x, dims-1) bs = size(x, dims) @@ -215,10 +215,12 @@ function (IN::InstanceNorm)(x) affine_shape[end-1] = c affine_shape[end] = bs m = prod(size(x)[1:end-2]) + γ, β = expand_inst(IN.γ, affine_shape), expand_inst(IN.β, affine_shape) if !IN.active - μ = reshape(repeat(IN.μ, outer=[bs]), affine_shape...) - σ² = reshape(repeat(IN.σ², outer=[bs]), affine_shape...) + μ = expand_inst(IN.μ, affine_shape) + σ² = expand_inst(IN.σ², affine_shape) + ϵ = IN.ϵ else T = eltype(x) @@ -229,14 +231,13 @@ function (IN::InstanceNorm)(x) # update moving mean/std mtm = data(convert(T, IN.momentum)) - IN.μ = reshape(mean((1 - mtm) .* repeat(IN.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), :) - IN.σ² = reshape(mean(((1 - mtm) .* repeat(IN.σ², outer=[1, bs]) .+ mtm .* reshape(data(σ²), (c, bs)) .* (m / (m - 1))), dims = 2), :) + IN.μ = reshape(mean(repeat((1 - mtm) .* IN.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), :) + IN.σ² = reshape(mean((repeat((1 - mtm) .* IN.σ², outer=[1, bs]) .+ reshape(data(σ²), (c, bs)) .* (mtm * m / (m - 1))), dims = 2), :) end let λ = IN.λ - temp = reshape(repeat(γ, outer=[bs]), affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ IN.ϵ)) - # This is intentionally not fused because of an extreme slowdown doing so - λ.(temp .+ reshape(repeat(β, outer=[bs]), affine_shape...)) + x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ) + λ.(γ .* x̂ .+ β) end end From 83b4b3a7140592f2a8860cb12af23f55ae407a29 Mon Sep 17 00:00:00 2001 From: David Pollack Date: Wed, 27 Feb 2019 12:03:29 +0100 Subject: [PATCH 08/10] changes based on PR comments --- src/layers/normalise.jl | 5 +++-- test/layers/normalisation.jl | 17 ++++++++++++++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 168f3363..7562e84f 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -186,6 +186,7 @@ m = Chain( ``` """ expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...) 
+ mutable struct InstanceNorm{F,V,W,N} λ::F # activation function β::V # bias @@ -231,8 +232,8 @@ function (IN::InstanceNorm)(x) # update moving mean/std mtm = data(convert(T, IN.momentum)) - IN.μ = reshape(mean(repeat((1 - mtm) .* IN.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), :) - IN.σ² = reshape(mean((repeat((1 - mtm) .* IN.σ², outer=[1, bs]) .+ reshape(data(σ²), (c, bs)) .* (mtm * m / (m - 1))), dims = 2), :) + IN.μ = dropdims(mean(repeat((1 - mtm) .* IN.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), dims=2) + IN.σ² = dropdims(mean((repeat((1 - mtm) .* IN.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (c, bs))), dims = 2), dims=2) end let λ = IN.λ diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index a249a4f4..d8629445 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -179,7 +179,22 @@ end @test m(x) == y end - let m = BatchNorm(32), x = randn(Float32, 416, 416, 32, 1); + # check that μ, σ², and the output are the correct size for higher rank tensors + let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6), + x = param(reshape(collect(1:prod(sizes)), sizes)) + y = m(x) + @test size(m.μ) == (sizes[end - 1], ) + @test size(m.σ²) == (sizes[end - 1], ) + @test size(y) == sizes + end + + # show that instance norm is equal to batch norm when channel and batch dims are squashed + let m_inorm = InstanceNorm(2), m_bnorm = BatchNorm(12), sizes = (5, 5, 3, 4, 2, 6), + x = param(reshape(collect(1:prod(sizes)), sizes)) + @test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes) + end + + let m = InstanceNorm(32), x = randn(Float32, 416, 416, 32, 1); m(x) @test (@allocated m(x)) < 100_000_000 end From 7b9b64f1cbfae549acac8f9b4f8996395f322b51 Mon Sep 17 00:00:00 2001 From: David Pollack Date: Thu, 7 Mar 2019 09:44:55 +0100 Subject: [PATCH 09/10] change IN to in --- src/layers/normalise.jl | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 7562e84f..054ca08b 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -203,9 +203,9 @@ InstanceNorm(chs::Integer, λ = identity; InstanceNorm(λ, param(initβ(chs)), param(initγ(chs)), zeros(chs), ones(chs), ϵ, momentum, true) -function (IN::InstanceNorm)(x) - size(x, ndims(x)-1) == length(IN.β) || - error("InstanceNorm expected $(length(IN.β)) channels, got $(size(x, ndims(x)-1))") +function (in::InstanceNorm)(x) + size(x, ndims(x)-1) == length(in.β) || + error("InstanceNorm expected $(length(in.β)) channels, got $(size(x, ndims(x)-1))") ndims(x) > 2 || error("InstanceNorm requires at least 3 dimensions. 
With 2 dimensions an array of zeros would be returned") # these are repeated later on depending on the batch size @@ -216,39 +216,39 @@ function (IN::InstanceNorm)(x) affine_shape[end-1] = c affine_shape[end] = bs m = prod(size(x)[1:end-2]) - γ, β = expand_inst(IN.γ, affine_shape), expand_inst(IN.β, affine_shape) + γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape) - if !IN.active - μ = expand_inst(IN.μ, affine_shape) - σ² = expand_inst(IN.σ², affine_shape) - ϵ = IN.ϵ + if !in.active + μ = expand_inst(in.μ, affine_shape) + σ² = expand_inst(in.σ², affine_shape) + ϵ = in.ϵ else T = eltype(x) - ϵ = data(convert(T, IN.ϵ)) + ϵ = data(convert(T, in.ϵ)) axes = 1:dims-2 # axes to reduce along (all but channels and batch size axes) μ = mean(x, dims = axes) σ² = mean((x .- μ) .^ 2, dims = axes) # update moving mean/std - mtm = data(convert(T, IN.momentum)) - IN.μ = dropdims(mean(repeat((1 - mtm) .* IN.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), dims=2) - IN.σ² = dropdims(mean((repeat((1 - mtm) .* IN.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (c, bs))), dims = 2), dims=2) + mtm = data(convert(T, in.momentum)) + in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), dims=2) + in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (c, bs))), dims = 2), dims=2) end - let λ = IN.λ + let λ = in.λ x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ) λ.(γ .* x̂ .+ β) end end -children(IN::InstanceNorm) = - (IN.λ, IN.β, IN.γ, IN.μ, IN.σ², IN.ϵ, IN.momentum, IN.active) +children(in::InstanceNorm) = + (in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum, in.active) -mapchildren(f, IN::InstanceNorm) = # e.g. mapchildren(cu, IN) - InstanceNorm(IN.λ, f(IN.β), f(IN.γ), f(IN.μ), f(IN.σ²), IN.ϵ, IN.momentum, IN.active) +mapchildren(f, in::InstanceNorm) = # e.g. mapchildren(cu, in) + InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum, in.active) -_testmode!(IN::InstanceNorm, test) = (IN.active = !test) +_testmode!(in::InstanceNorm, test) = (in.active = !test) function Base.show(io::IO, l::InstanceNorm) print(io, "InstanceNorm($(join(size(l.β), ", "))") From c8badcd12f69f3510a473442ce8b8e7f725ec7a2 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 7 Mar 2019 11:23:14 +0000 Subject: [PATCH 10/10] add news.md --- NEWS.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 NEWS.md diff --git a/NEWS.md b/NEWS.md new file mode 100644 index 00000000..55f4c96d --- /dev/null +++ b/NEWS.md @@ -0,0 +1,19 @@ +# v0.8.0 + +* New [ConvTranspose layer](https://github.com/FluxML/Flux.jl/pull/311). +* Datasets are now [hash verified on download](https://github.com/FluxML/Flux.jl/pull/585) to avoid corruption. +* We now [zero the initial state for RNNs](https://github.com/FluxML/Flux.jl/pull/590/). +* [Normalisation can now work on arbitrary `dims`.](https://github.com/FluxML/Flux.jl/pull/592) +* Many docs and bugfixes thanks to @KristofferC and others. +* [NamedTuples now work like Tuples](https://github.com/FluxML/Flux.jl/pull/603) when doing `mapleaves`. +* New "performance tips" [section of the docs](https://github.com/FluxML/Flux.jl/pull/615). +* The training loop is [now more readable](https://github.com/FluxML/Flux.jl/pull/651) and better shows how to use the lower-level APIs. + +AD Changes: + +* `det`, `logdet` and `logabsdet` [now have adjoints](https://github.com/FluxML/Flux.jl/pull/596/files). 
+* Support for [PermuteDimsArray](https://github.com/FluxML/Flux.jl/pull/576). + +# v0.7.0 + +Despite the heroic efforts of scholars and archeologists, pre-0.7 history is lost to the sands of time.
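
Taken together, patches 03–05 replace the old `back!`-and-`.grad` training loop with explicit `Params`/`Grads`. The sketch below is adapted from the updated `train!` and `test/optimise.jl` hunks above; it assumes Flux at the state of this patch series (Tracker-based AD, `param`, and the `Descent` optimiser) and is illustrative rather than part of the series itself.

```julia
using Flux, Flux.Tracker, Flux.Optimise
using Flux.Tracker: Params, gradient

w  = randn(10, 10)                 # fixed "target" weights
w′ = param(randn(10, 10))          # tracked parameters to be optimised
loss(x) = Flux.mse(w * x, w′ * x)

opt = Descent(0.1)
θ = Params([w′])
for _ = 1:10^4
  # implicit-parameter gradients: a Grads object keyed by each parameter in θ
  gs = gradient(() -> loss(rand(10)), θ)
  # apply the optimiser step and update each parameter in place
  Optimise.update!(opt, θ, gs)
end
```

This mirrors what the new `train!` does per datapoint: compute a `Grads` over an explicit `Params` collection and hand it to `update!`, rather than calling `back!` and zeroing `.grad` fields by hand.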