Merge branch 'master' into patch-3
This commit is contained in:
commit
47c1324476
19
NEWS.md
Normal file
19
NEWS.md
Normal file
@ -0,0 +1,19 @@
|
||||
# v0.8.0
|
||||
|
||||
* New [ConvTranspose layer](https://github.com/FluxML/Flux.jl/pull/311).
|
||||
* Datasets are now [hash verified on download](https://github.com/FluxML/Flux.jl/pull/585) to avoid corruption.
|
||||
* We now [zero the initial state for RNNs](https://github.com/FluxML/Flux.jl/pull/590/).
|
||||
* [Normalisation can now work on arbitrary `dims`.](https://github.com/FluxML/Flux.jl/pull/592)
|
||||
* Many docs and bugfixes thanks to @KristofferC and others.
|
||||
* [NamedTuples now work like Tuples](https://github.com/FluxML/Flux.jl/pull/603) when doing `mapleaves`.
|
||||
* New "performance tips" [section of the docs](https://github.com/FluxML/Flux.jl/pull/615).
|
||||
* The training loop is [now more readable](https://github.com/FluxML/Flux.jl/pull/651) and better shows how to use the lower-level APIs.
|
||||
|
||||
AD Changes:
|
||||
|
||||
* `det`, `logdet` and `logabsdet` [now have adjoints](https://github.com/FluxML/Flux.jl/pull/596/files).
|
||||
* Support for [PermuteDimsArray](https://github.com/FluxML/Flux.jl/pull/576).
|
||||
|
||||
# v0.7.0
|
||||
|
||||
Despite the heroic efforts of scholars and archeologists, pre-0.7 history is lost to the sands of time.
|
@ -7,7 +7,7 @@ using MacroTools, Juno, Requires, Reexport, Statistics, Random
|
||||
using MacroTools: @forward
|
||||
|
||||
export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
|
||||
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm,
|
||||
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm,
|
||||
params, mapleaves, cpu, gpu, f32, f64
|
||||
|
||||
@reexport using NNlib
|
||||
|
@ -144,34 +144,32 @@ BatchNorm(chs::Integer, λ = identity;
|
||||
function (BN::BatchNorm)(x)
|
||||
size(x, ndims(x)-1) == length(BN.β) ||
|
||||
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
|
||||
γ, β = BN.γ, BN.β
|
||||
dims = length(size(x))
|
||||
channels = size(x, dims-1)
|
||||
affine_shape = ones(Int, dims)
|
||||
affine_shape[end-1] = channels
|
||||
m = prod(size(x)[1:end-2]) * size(x)[end]
|
||||
|
||||
γ = reshape(BN.γ, affine_shape...)
|
||||
β = reshape(BN.β, affine_shape...)
|
||||
if !BN.active
|
||||
μ = reshape(BN.μ, affine_shape...)
|
||||
σ² = reshape(BN.σ², affine_shape...)
|
||||
ϵ = BN.ϵ
|
||||
else
|
||||
T = eltype(x)
|
||||
|
||||
ϵ = data(convert(T, BN.ϵ))
|
||||
axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
|
||||
μ = mean(x, dims = axes)
|
||||
σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
|
||||
|
||||
ϵ = data(convert(T, BN.ϵ))
|
||||
# update moving mean/std
|
||||
mtm = data(convert(T, BN.momentum))
|
||||
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :)
|
||||
BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* reshape(data(σ²), :) .* m ./ (m - 1))
|
||||
BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), :)
|
||||
end
|
||||
|
||||
let λ = BN.λ
|
||||
temp = reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ BN.ϵ))
|
||||
# This is intentionally not fused because of an extreme slowdown doing so
|
||||
λ.(temp .+ reshape(β, affine_shape...))
|
||||
x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)
|
||||
λ.(γ .* x̂ .+ β)
|
||||
end
|
||||
end
|
||||
|
||||
@ -188,3 +186,103 @@ function Base.show(io::IO, l::BatchNorm)
|
||||
(l.λ == identity) || print(io, ", λ = $(l.λ)")
|
||||
print(io, ")")
|
||||
end
|
||||
|
||||
|
||||
"""
|
||||
InstanceNorm(channels::Integer, σ = identity;
|
||||
initβ = zeros, initγ = ones,
|
||||
ϵ = 1e-8, momentum = .1)
|
||||
|
||||
Instance Normalization layer. The `channels` input should be the size of the
|
||||
channel dimension in your data (see below).
|
||||
|
||||
Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
|
||||
a batch of feature vectors this is just the data dimension, for `WHCN` images
|
||||
it's the usual channel dimension.)
|
||||
|
||||
`InstanceNorm` computes the mean and variance for each each `W×H×1×1` slice and
|
||||
shifts them to have a new mean and variance (corresponding to the learnable,
|
||||
per-channel `bias` and `scale` parameters).
|
||||
|
||||
See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
|
||||
|
||||
Example:
|
||||
```julia
|
||||
m = Chain(
|
||||
Dense(28^2, 64),
|
||||
InstanceNorm(64, relu),
|
||||
Dense(64, 10),
|
||||
InstanceNorm(10),
|
||||
softmax)
|
||||
```
|
||||
"""
|
||||
expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
|
||||
|
||||
mutable struct InstanceNorm{F,V,W,N}
|
||||
λ::F # activation function
|
||||
β::V # bias
|
||||
γ::V # scale
|
||||
μ::W # moving mean
|
||||
σ²::W # moving std
|
||||
ϵ::N
|
||||
momentum::N
|
||||
active::Bool
|
||||
end
|
||||
|
||||
InstanceNorm(chs::Integer, λ = identity;
|
||||
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
|
||||
InstanceNorm(λ, param(initβ(chs)), param(initγ(chs)),
|
||||
zeros(chs), ones(chs), ϵ, momentum, true)
|
||||
|
||||
function (in::InstanceNorm)(x)
|
||||
size(x, ndims(x)-1) == length(in.β) ||
|
||||
error("InstanceNorm expected $(length(in.β)) channels, got $(size(x, ndims(x)-1))")
|
||||
ndims(x) > 2 ||
|
||||
error("InstanceNorm requires at least 3 dimensions. With 2 dimensions an array of zeros would be returned")
|
||||
# these are repeated later on depending on the batch size
|
||||
dims = length(size(x))
|
||||
c = size(x, dims-1)
|
||||
bs = size(x, dims)
|
||||
affine_shape = ones(Int, dims)
|
||||
affine_shape[end-1] = c
|
||||
affine_shape[end] = bs
|
||||
m = prod(size(x)[1:end-2])
|
||||
γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape)
|
||||
|
||||
if !in.active
|
||||
μ = expand_inst(in.μ, affine_shape)
|
||||
σ² = expand_inst(in.σ², affine_shape)
|
||||
ϵ = in.ϵ
|
||||
else
|
||||
T = eltype(x)
|
||||
|
||||
ϵ = data(convert(T, in.ϵ))
|
||||
axes = 1:dims-2 # axes to reduce along (all but channels and batch size axes)
|
||||
μ = mean(x, dims = axes)
|
||||
σ² = mean((x .- μ) .^ 2, dims = axes)
|
||||
|
||||
# update moving mean/std
|
||||
mtm = data(convert(T, in.momentum))
|
||||
in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), dims=2)
|
||||
in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (c, bs))), dims = 2), dims=2)
|
||||
end
|
||||
|
||||
let λ = in.λ
|
||||
x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)
|
||||
λ.(γ .* x̂ .+ β)
|
||||
end
|
||||
end
|
||||
|
||||
children(in::InstanceNorm) =
|
||||
(in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum, in.active)
|
||||
|
||||
mapchildren(f, in::InstanceNorm) = # e.g. mapchildren(cu, in)
|
||||
InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum, in.active)
|
||||
|
||||
_testmode!(in::InstanceNorm, test) = (in.active = !test)
|
||||
|
||||
function Base.show(io::IO, l::InstanceNorm)
|
||||
print(io, "InstanceNorm($(join(size(l.β), ", "))")
|
||||
(l.λ == identity) || print(io, ", λ = $(l.λ)")
|
||||
print(io, ")")
|
||||
end
|
||||
|
@ -37,7 +37,7 @@ Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
|
||||
|
||||
function apply!(o::Momentum, x, Δ)
|
||||
η, ρ = o.eta, o.rho
|
||||
v = get!(o.velocity, x, zero(x))::typeof(x)
|
||||
v = get!(o.velocity, x, zero(x))::typeof(data(x))
|
||||
@. v = ρ * v - η * Δ
|
||||
@. Δ = -v
|
||||
end
|
||||
@ -57,7 +57,7 @@ Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
|
||||
|
||||
function apply!(o::Nesterov, x, Δ)
|
||||
η, ρ = o.eta, o.rho
|
||||
v = get!(o.velocity, x, zero(x))::typeof(x)
|
||||
v = get!(o.velocity, x, zero(x))::typeof(data(x))
|
||||
d = @. ρ^2 * v - (1+ρ) * η * Δ
|
||||
@. v = ρ*v - η*Δ
|
||||
@. Δ = -d
|
||||
@ -80,7 +80,7 @@ RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
|
||||
|
||||
function apply!(o::RMSProp, x, Δ)
|
||||
η, ρ = o.eta, o.rho
|
||||
acc = get!(o.acc, x, zero(x))::typeof(x)
|
||||
acc = get!(o.acc, x, zero(x))::typeof(data(x))
|
||||
@. acc = ρ * acc + (1 - ρ) * Δ^2
|
||||
@. Δ *= η / (√acc + ϵ)
|
||||
end
|
||||
@ -147,7 +147,7 @@ ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
|
||||
|
||||
function apply!(o::ADAGrad, x, Δ)
|
||||
η = o.eta
|
||||
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
|
||||
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(data(x))
|
||||
@. acc += Δ^2
|
||||
@. Δ *= η / (√acc + ϵ)
|
||||
end
|
||||
@ -321,7 +321,7 @@ end
|
||||
|
||||
WeightDecay() = WeightDecay(0)
|
||||
|
||||
function apply!(o::WeightDecay, x, Δ)
|
||||
function apply!(o::WeightDecay, x, Δ)
|
||||
wd = o.wd
|
||||
@. Δ += wd * x
|
||||
@. Δ += wd * data(x)
|
||||
end
|
||||
|
@ -1,16 +1,23 @@
|
||||
using Juno
|
||||
import Flux.Tracker: data, grad, back!, update!
|
||||
import Flux.Tracker: Params, gradient, data, update!
|
||||
import Base.depwarn
|
||||
|
||||
function update!(opt, x, x̄)
|
||||
update!(x, apply!(opt, x, copy(data(x̄))))
|
||||
update!(x, -apply!(opt, x, data(x̄)))
|
||||
end
|
||||
|
||||
function _update_params!(opt, xs)
|
||||
function update!(opt, xs::Params, gs)
|
||||
for x in xs
|
||||
Δ = apply!(opt, x.data, x.grad)
|
||||
x.data .-= Δ
|
||||
Δ .= 0
|
||||
update!(opt, x, gs[x])
|
||||
end
|
||||
end
|
||||
|
||||
# Added as an internal API but everyone started using it.
|
||||
function _update_params!(opt, xs)
|
||||
depwarn("`_update_params!` is deprecated, use `update!` instead.", :stop)
|
||||
for x in xs
|
||||
update!(opt, x, Tracker.grad(x))
|
||||
x.tracker.grad = Tracker.zero_grad!(x.tracker.grad)
|
||||
end
|
||||
end
|
||||
|
||||
@ -19,16 +26,6 @@ call(f, xs...) = f(xs...)
|
||||
runall(f) = f
|
||||
runall(fs::AbstractVector) = () -> foreach(call, fs)
|
||||
|
||||
# The AD generates fairly large backtraces that are unhelpful if you interrupt
|
||||
# while training; this just cleans that up.
|
||||
macro interrupts(ex)
|
||||
:(try $(esc(ex))
|
||||
catch e
|
||||
e isa InterruptException || rethrow()
|
||||
throw(e)
|
||||
end)
|
||||
end
|
||||
|
||||
struct StopException <: Exception end
|
||||
"""
|
||||
stop()
|
||||
@ -67,13 +64,14 @@ The callback can call `Flux.stop()` to interrupt the training loop.
|
||||
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
|
||||
"""
|
||||
function train!(loss, ps, data, opt; cb = () -> ())
|
||||
ps = Params(ps)
|
||||
cb = runall(cb)
|
||||
opt = runall(opt)
|
||||
@progress for d in data
|
||||
try
|
||||
l = loss(d...)
|
||||
@interrupts back!(l)
|
||||
_update_params!(opt, ps)
|
||||
gs = gradient(ps) do
|
||||
loss(d...)
|
||||
end
|
||||
update!(opt, ps, gs)
|
||||
if cb() == :stop
|
||||
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
|
||||
break
|
||||
|
@ -62,6 +62,7 @@ macro grad(ex)
|
||||
end
|
||||
|
||||
include("idset.jl")
|
||||
include("params.jl")
|
||||
include("back.jl")
|
||||
include("numeric.jl")
|
||||
include("lib/real.jl")
|
||||
|
@ -1,3 +1,15 @@
|
||||
# The AD generates fairly large backtraces that are unhelpful if you interrupt
|
||||
# while training; this just cleans that up.
|
||||
macro interrupts(ex)
|
||||
:(try $(esc(ex))
|
||||
catch e
|
||||
e isa InterruptException || rethrow()
|
||||
throw(e)
|
||||
end)
|
||||
end
|
||||
|
||||
# In-place gradients
|
||||
|
||||
init_grad(x) = zero(x)
|
||||
zero_grad!(x) = zero(x)
|
||||
zero_grad!(x::AbstractArray) = (x .= 0)
|
||||
@ -66,64 +78,34 @@ function back!(x, Δ; once = true)
|
||||
return
|
||||
end
|
||||
|
||||
function extract_grad!(x)
|
||||
x̄ = copy(grad(x))
|
||||
x̄ = nobacksies("Use `gradient(...; nest = true)` for nested derivatives", x̄)
|
||||
tracker(x).grad = zero_grad!(grad(x))
|
||||
return x̄
|
||||
end
|
||||
|
||||
function gradient_(f, xs...)
|
||||
xs = param.(data.(xs))
|
||||
l = f(xs...)
|
||||
losscheck(l)
|
||||
back!(l)
|
||||
nobacksies("Use `gradient(...; nest = true)` for nested derivatives",
|
||||
grad.(xs))
|
||||
@interrupts back!(l)
|
||||
extract_grad!.(xs)
|
||||
end
|
||||
|
||||
function gradient_(f, xs::Params)
|
||||
l = f()
|
||||
losscheck(l)
|
||||
@interrupts back!(l)
|
||||
gs = Grads()
|
||||
for x in xs
|
||||
gs[tracker(x)] = extract_grad!(x)
|
||||
end
|
||||
return gs
|
||||
end
|
||||
|
||||
# Out-of-place gradients
|
||||
|
||||
struct Params
|
||||
order::Vector{Any}
|
||||
params::IdSet{Any}
|
||||
Params() = new([], IdSet())
|
||||
end
|
||||
|
||||
@forward Params.order Base.iterate, Base.length
|
||||
|
||||
function Base.push!(ps::Params, x)
|
||||
if !(x in ps.params)
|
||||
push!(ps.order, x)
|
||||
push!(ps.params, x)
|
||||
end
|
||||
return ps
|
||||
end
|
||||
|
||||
Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
|
||||
|
||||
Params(xs) = push!(Params(), xs...)
|
||||
|
||||
function Base.show(io::IO, ps::Params)
|
||||
print(io, "Params([")
|
||||
join(io, ps.order, ", ")
|
||||
print(io, "])")
|
||||
end
|
||||
|
||||
struct Grads
|
||||
grads::IdDict{Any,Any}
|
||||
end
|
||||
|
||||
Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
|
||||
|
||||
Grads() = Grads(IdDict())
|
||||
|
||||
@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate
|
||||
|
||||
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
|
||||
|
||||
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
|
||||
|
||||
function Base.getindex(g::Grads, x)
|
||||
istracked(x) || error("Object not tracked: $x")
|
||||
g[tracker(x)]
|
||||
end
|
||||
|
||||
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
|
||||
|
||||
function back_(g::Grads, c::Call, Δ)
|
||||
Δs = c.func(Δ)
|
||||
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
||||
@ -182,8 +164,6 @@ end
|
||||
gradient(f, xs...; nest = false) =
|
||||
nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
|
||||
|
||||
gradient(f, ps::Params) = gradient_nested(f, ps)
|
||||
|
||||
# Jacobians and Hessians
|
||||
|
||||
import ..Flux
|
||||
|
@ -71,6 +71,11 @@ function update!(x::TrackedArray, Δ)
|
||||
return x
|
||||
end
|
||||
|
||||
function update!(x::AbstractArray, Δ)
|
||||
x .+= data(Δ)
|
||||
return x
|
||||
end
|
||||
|
||||
# Fallthrough methods
|
||||
|
||||
for f in :[Base.size, Base.ndims, Base.collect].args
|
||||
|
46
src/tracker/params.jl
Normal file
46
src/tracker/params.jl
Normal file
@ -0,0 +1,46 @@
|
||||
struct Params
|
||||
order::Vector{Any}
|
||||
params::IdSet{Any}
|
||||
Params() = new([], IdSet())
|
||||
end
|
||||
|
||||
@forward Params.order Base.iterate, Base.length
|
||||
|
||||
function Base.push!(ps::Params, x)
|
||||
if !(x in ps.params)
|
||||
push!(ps.order, x)
|
||||
push!(ps.params, x)
|
||||
end
|
||||
return ps
|
||||
end
|
||||
|
||||
Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
|
||||
|
||||
Params(xs) = push!(Params(), xs...)
|
||||
|
||||
function Base.show(io::IO, ps::Params)
|
||||
print(io, "Params([")
|
||||
join(io, ps.order, ", ")
|
||||
print(io, "])")
|
||||
end
|
||||
|
||||
struct Grads
|
||||
grads::IdDict{Any,Any}
|
||||
end
|
||||
|
||||
Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
|
||||
|
||||
Grads() = Grads(IdDict())
|
||||
|
||||
@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate
|
||||
|
||||
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
|
||||
|
||||
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
|
||||
|
||||
function Base.getindex(g::Grads, x)
|
||||
istracked(x) || error("Object not tracked: $x")
|
||||
g[tracker(x)]
|
||||
end
|
||||
|
||||
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
|
@ -104,3 +104,99 @@ end
|
||||
@test (@allocated m(x)) < 100_000_000
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
@testset "InstanceNorm" begin
|
||||
# helper functions
|
||||
expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
|
||||
# begin tests
|
||||
let m = InstanceNorm(2), sizes = (3, 2, 2),
|
||||
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||
|
||||
@test m.β.data == [0, 0] # initβ(2)
|
||||
@test m.γ.data == [1, 1] # initγ(2)
|
||||
|
||||
@test m.active
|
||||
|
||||
m(x)
|
||||
|
||||
#julia> x
|
||||
#[:, :, 1] =
|
||||
# 1.0 4.0
|
||||
# 2.0 5.0
|
||||
# 3.0 6.0
|
||||
#
|
||||
#[:, :, 2] =
|
||||
# 7.0 10.0
|
||||
# 8.0 11.0
|
||||
# 9.0 12.0
|
||||
#
|
||||
# μ will be
|
||||
# (1. + 2. + 3.) / 3 = 2.
|
||||
# (4. + 5. + 6.) / 3 = 5.
|
||||
#
|
||||
# (7. + 8. + 9.) / 3 = 8.
|
||||
# (10. + 11. + 12.) / 3 = 11.
|
||||
#
|
||||
# ∴ update rule with momentum:
|
||||
# (1. - .1) * 0 + .1 * (2. + 8.) / 2 = .5
|
||||
# (1. - .1) * 0 + .1 * (5. + 11.) / 2 = .8
|
||||
@test m.μ ≈ [0.5, 0.8]
|
||||
# momentum * var * num_items / (num_items - 1) + (1 - momentum) * sigma_sq
|
||||
# julia> reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
|
||||
# 2-element Array{Float64,1}:
|
||||
# 1.
|
||||
# 1.
|
||||
@test m.σ² ≈ reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
|
||||
|
||||
testmode!(m)
|
||||
@test !m.active
|
||||
|
||||
x′ = m(x).data
|
||||
@test isapprox(x′[1], (1 - 0.5) / sqrt(1. + 1f-5), atol = 1.0e-5)
|
||||
end
|
||||
# with activation function
|
||||
let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2),
|
||||
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||
|
||||
affine_shape = collect(sizes)
|
||||
affine_shape[1] = 1
|
||||
|
||||
@test m.active
|
||||
m(x)
|
||||
|
||||
testmode!(m)
|
||||
@test !m.active
|
||||
|
||||
y = m(x).data
|
||||
@test isapprox(y, data(sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ))), atol = 1.0e-7)
|
||||
end
|
||||
|
||||
let m = InstanceNorm(2), sizes = (2, 4, 1, 2, 3),
|
||||
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
|
||||
y = reshape(m(y), sizes...)
|
||||
@test m(x) == y
|
||||
end
|
||||
|
||||
# check that μ, σ², and the output are the correct size for higher rank tensors
|
||||
let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6),
|
||||
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||
y = m(x)
|
||||
@test size(m.μ) == (sizes[end - 1], )
|
||||
@test size(m.σ²) == (sizes[end - 1], )
|
||||
@test size(y) == sizes
|
||||
end
|
||||
|
||||
# show that instance norm is equal to batch norm when channel and batch dims are squashed
|
||||
let m_inorm = InstanceNorm(2), m_bnorm = BatchNorm(12), sizes = (5, 5, 3, 4, 2, 6),
|
||||
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||
@test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
|
||||
end
|
||||
|
||||
let m = InstanceNorm(32), x = randn(Float32, 416, 416, 32, 1);
|
||||
m(x)
|
||||
@test (@allocated m(x)) < 100_000_000
|
||||
end
|
||||
|
||||
end
|
||||
|
@ -4,21 +4,15 @@ using Flux.Tracker
|
||||
using Test
|
||||
@testset "Optimise" begin
|
||||
w = randn(10, 10)
|
||||
@testset for Opt in [ADAMW, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
|
||||
@testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
|
||||
NADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
|
||||
Momentum()]
|
||||
w′ = param(randn(10, 10))
|
||||
loss(x) = Flux.mse(w*x, w′*x)
|
||||
opt = Opt(0.001)
|
||||
if opt isa Descent || opt isa ADAGrad
|
||||
opt = Opt(0.1)
|
||||
end
|
||||
if opt isa ADADelta
|
||||
opt = Opt(0.9)
|
||||
end
|
||||
for t = 1: 10^5
|
||||
l = loss(rand(10))
|
||||
back!(l)
|
||||
delta = Optimise.apply!(opt, w′.data, w′.grad)
|
||||
w′.data .-= delta
|
||||
θ = Params([w′])
|
||||
θ̄ = gradient(() -> loss(rand(10)), θ)
|
||||
Optimise.update!(opt, θ, θ̄)
|
||||
end
|
||||
@test Flux.mse(w, w′) < 0.01
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user