Merge branch 'master' into patch-3

commit 47c1324476

NEWS.md (new file, 19 lines)
@@ -0,0 +1,19 @@
+# v0.8.0
+
+* New [ConvTranspose layer](https://github.com/FluxML/Flux.jl/pull/311).
+* Datasets are now [hash verified on download](https://github.com/FluxML/Flux.jl/pull/585) to avoid corruption.
+* We now [zero the initial state for RNNs](https://github.com/FluxML/Flux.jl/pull/590/).
+* [Normalisation can now work on arbitrary `dims`.](https://github.com/FluxML/Flux.jl/pull/592)
+* Many docs and bugfixes thanks to @KristofferC and others.
+* [NamedTuples now work like Tuples](https://github.com/FluxML/Flux.jl/pull/603) when doing `mapleaves`.
+* New "performance tips" [section of the docs](https://github.com/FluxML/Flux.jl/pull/615).
+* The training loop is [now more readable](https://github.com/FluxML/Flux.jl/pull/651) and better shows how to use the lower-level APIs.
+
+AD Changes:
+
+* `det`, `logdet` and `logabsdet` [now have adjoints](https://github.com/FluxML/Flux.jl/pull/596/files).
+* Support for [PermuteDimsArray](https://github.com/FluxML/Flux.jl/pull/576).
+
+# v0.7.0
+
+Despite the heroic efforts of scholars and archeologists, pre-0.7 history is lost to the sands of time.
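The `mapleaves` entry above is easiest to see with a tiny sketch. This is illustrative only, assuming Flux v0.8 and that the NamedTuple handling added in FluxML/Flux.jl#603 mirrors the existing Tuple behaviour; the container below is made up:

```julia
using Flux

# A hypothetical parameter container using a NamedTuple.
nt = (W = rand(2, 3), b = zeros(2))

# Assumption: after #603, mapleaves recurses into NamedTuples the same way it
# recurses into Tuples, applying the function to every array leaf.
nt32 = Flux.mapleaves(x -> Float32.(x), nt)
```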
@@ -7,7 +7,7 @@ using MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward

 export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
-       DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm,
+       DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm,
        params, mapleaves, cpu, gpu, f32, f64

 @reexport using NNlib
@@ -144,34 +144,32 @@ BatchNorm(chs::Integer, λ = identity;
 function (BN::BatchNorm)(x)
   size(x, ndims(x)-1) == length(BN.β) ||
     error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
-  γ, β = BN.γ, BN.β
   dims = length(size(x))
   channels = size(x, dims-1)
   affine_shape = ones(Int, dims)
   affine_shape[end-1] = channels
   m = prod(size(x)[1:end-2]) * size(x)[end]
+  γ = reshape(BN.γ, affine_shape...)
+  β = reshape(BN.β, affine_shape...)
   if !BN.active
     μ = reshape(BN.μ, affine_shape...)
     σ² = reshape(BN.σ², affine_shape...)
+    ϵ = BN.ϵ
   else
     T = eltype(x)

-    ϵ = data(convert(T, BN.ϵ))
     axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
     μ = mean(x, dims = axes)
     σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
+    ϵ = data(convert(T, BN.ϵ))
     # update moving mean/std
     mtm = data(convert(T, BN.momentum))
     BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :)
-    BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* reshape(data(σ²), :) .* m ./ (m - 1))
+    BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), :)
   end

   let λ = BN.λ
-    temp = reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ BN.ϵ))
-    # This is intentionally not fused because of an extreme slowdown doing so
-    λ.(temp .+ reshape(β, affine_shape...))
+    x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)
+    λ.(γ .* x̂ .+ β)
   end
 end
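The rewritten moving-variance line above folds the Bessel correction `m / (m - 1)` into the momentum factor. A small worked sketch of that update rule, with made-up numbers (this is arithmetic only, not Flux API):

```julia
mtm = 0.1        # momentum
m   = 6          # elements reduced per channel (W*H*N for a WHCN input)
batch_mean = 2.5 # batch mean for one channel (illustrative value)
batch_var  = 1.2 # biased batch variance for that channel (illustrative value)

# moving mean: a plain exponential moving average, starting from 0.0
mu_new    = (1 - mtm) * 0.0 + mtm * batch_mean                  # 0.25

# moving variance: the batch estimate is unbiased via m / (m - 1), starting from 1.0
sigma2_new = (1 - mtm) * 1.0 + (mtm * m / (m - 1)) * batch_var  # 0.9 + 0.144 = 1.044
```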
@@ -188,3 +186,103 @@ function Base.show(io::IO, l::BatchNorm)
   (l.λ == identity) || print(io, ", λ = $(l.λ)")
   print(io, ")")
 end
+
+
+"""
+    InstanceNorm(channels::Integer, σ = identity;
+                 initβ = zeros, initγ = ones,
+                 ϵ = 1e-8, momentum = .1)
+
+Instance Normalization layer. The `channels` input should be the size of the
+channel dimension in your data (see below).
+
+Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
+a batch of feature vectors this is just the data dimension, for `WHCN` images
+it's the usual channel dimension.)
+
+`InstanceNorm` computes the mean and variance for each `W×H×1×1` slice and
+shifts them to have a new mean and variance (corresponding to the learnable,
+per-channel `bias` and `scale` parameters).
+
+See [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
+
+Example:
+```julia
+m = Chain(
+  Dense(28^2, 64),
+  InstanceNorm(64, relu),
+  Dense(64, 10),
+  InstanceNorm(10),
+  softmax)
+```
+"""
+expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
+
+mutable struct InstanceNorm{F,V,W,N}
+  λ::F  # activation function
+  β::V  # bias
+  γ::V  # scale
+  μ::W  # moving mean
+  σ²::W  # moving std
+  ϵ::N
+  momentum::N
+  active::Bool
+end
+
+InstanceNorm(chs::Integer, λ = identity;
+    initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
+  InstanceNorm(λ, param(initβ(chs)), param(initγ(chs)),
+               zeros(chs), ones(chs), ϵ, momentum, true)
+
+function (in::InstanceNorm)(x)
+  size(x, ndims(x)-1) == length(in.β) ||
+    error("InstanceNorm expected $(length(in.β)) channels, got $(size(x, ndims(x)-1))")
+  ndims(x) > 2 ||
+    error("InstanceNorm requires at least 3 dimensions. With 2 dimensions an array of zeros would be returned")
+  # these are repeated later on depending on the batch size
+  dims = length(size(x))
+  c = size(x, dims-1)
+  bs = size(x, dims)
+  affine_shape = ones(Int, dims)
+  affine_shape[end-1] = c
+  affine_shape[end] = bs
+  m = prod(size(x)[1:end-2])
+  γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape)
+
+  if !in.active
+    μ = expand_inst(in.μ, affine_shape)
+    σ² = expand_inst(in.σ², affine_shape)
+    ϵ = in.ϵ
+  else
+    T = eltype(x)
+
+    ϵ = data(convert(T, in.ϵ))
+    axes = 1:dims-2 # axes to reduce along (all but channels and batch size axes)
+    μ = mean(x, dims = axes)
+    σ² = mean((x .- μ) .^ 2, dims = axes)
+
+    # update moving mean/std
+    mtm = data(convert(T, in.momentum))
+    in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), dims=2)
+    in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (c, bs))), dims = 2), dims=2)
+  end
+
+  let λ = in.λ
+    x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)
+    λ.(γ .* x̂ .+ β)
+  end
+end
+
+children(in::InstanceNorm) =
+  (in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum, in.active)
+
+mapchildren(f, in::InstanceNorm) = # e.g. mapchildren(cu, in)
+  InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum, in.active)
+
+_testmode!(in::InstanceNorm, test) = (in.active = !test)
+
+function Base.show(io::IO, l::InstanceNorm)
+  print(io, "InstanceNorm($(join(size(l.β), ", "))")
+  (l.λ == identity) || print(io, ", λ = $(l.λ)")
+  print(io, ")")
+end
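As a quick illustration of the layer added above, here is a minimal usage sketch, assuming Flux v0.8 with this patch applied (the shapes below are arbitrary):

```julia
using Flux

x = rand(Float32, 8, 8, 3, 2)  # WHCN input: 8×8 images, 3 channels, batch of 2
m = InstanceNorm(3, relu)      # one β/γ pair per channel

y = m(x)              # training mode: statistics are computed per channel and per sample
Flux.testmode!(m)     # switch to the stored moving statistics
y2 = m(x)
size(y) == size(x)    # the output keeps the input shape
```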
@@ -37,7 +37,7 @@ Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())

 function apply!(o::Momentum, x, Δ)
   η, ρ = o.eta, o.rho
-  v = get!(o.velocity, x, zero(x))::typeof(x)
+  v = get!(o.velocity, x, zero(x))::typeof(data(x))
   @. v = ρ * v - η * Δ
   @. Δ = -v
 end
@@ -57,7 +57,7 @@ Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())

 function apply!(o::Nesterov, x, Δ)
   η, ρ = o.eta, o.rho
-  v = get!(o.velocity, x, zero(x))::typeof(x)
+  v = get!(o.velocity, x, zero(x))::typeof(data(x))
   d = @. ρ^2 * v - (1+ρ) * η * Δ
   @. v = ρ*v - η*Δ
   @. Δ = -d
@@ -80,7 +80,7 @@ RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())

 function apply!(o::RMSProp, x, Δ)
   η, ρ = o.eta, o.rho
-  acc = get!(o.acc, x, zero(x))::typeof(x)
+  acc = get!(o.acc, x, zero(x))::typeof(data(x))
   @. acc = ρ * acc + (1 - ρ) * Δ^2
   @. Δ *= η / (√acc + ϵ)
 end
@@ -147,7 +147,7 @@ ADAGrad(η = 0.1) = ADAGrad(η, IdDict())

 function apply!(o::ADAGrad, x, Δ)
   η = o.eta
-  acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
+  acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(data(x))
   @. acc += Δ^2
   @. Δ *= η / (√acc + ϵ)
 end
@@ -323,5 +323,5 @@ WeightDecay() = WeightDecay(0)

 function apply!(o::WeightDecay, x, Δ)
   wd = o.wd
-  @. Δ += wd * x
+  @. Δ += wd * data(x)
 end
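For context on `apply!`, which the hunks above modify: it rewrites a gradient into the optimiser's update step in place, keeping per-parameter state keyed by the parameter. A minimal sketch with a plain array, assuming Flux v0.8:

```julia
using Flux

w   = randn(4)              # a parameter
Δ   = fill(0.25, 4)         # its gradient
opt = Momentum(0.1, 0.9)    # Momentum(η, ρ) with fresh per-parameter state

step = Flux.Optimise.apply!(opt, w, Δ)  # Δ now holds the step for w (and is returned)
w  .-= step                             # gradient-descent update
```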
@@ -1,16 +1,23 @@
 using Juno
-import Flux.Tracker: data, grad, back!, update!
+import Flux.Tracker: Params, gradient, data, update!
 import Base.depwarn

 function update!(opt, x, x̄)
-  update!(x, apply!(opt, x, copy(data(x̄))))
+  update!(x, -apply!(opt, x, data(x̄)))
 end

-function _update_params!(opt, xs)
+function update!(opt, xs::Params, gs)
   for x in xs
-    Δ = apply!(opt, x.data, x.grad)
-    x.data .-= Δ
-    Δ .= 0
+    update!(opt, x, gs[x])
+  end
+end
+
+# Added as an internal API but everyone started using it.
+function _update_params!(opt, xs)
+  depwarn("`_update_params!` is deprecated, use `update!` instead.", :stop)
+  for x in xs
+    update!(opt, x, Tracker.grad(x))
+    x.tracker.grad = Tracker.zero_grad!(x.tracker.grad)
   end
 end
@@ -19,16 +26,6 @@ call(f, xs...) = f(xs...)
 runall(f) = f
 runall(fs::AbstractVector) = () -> foreach(call, fs)

-# The AD generates fairly large backtraces that are unhelpful if you interrupt
-# while training; this just cleans that up.
-macro interrupts(ex)
-  :(try $(esc(ex))
-  catch e
-    e isa InterruptException || rethrow()
-    throw(e)
-  end)
-end
-
 struct StopException <: Exception end
 """
     stop()
@@ -67,13 +64,14 @@ The callback can call `Flux.stop()` to interrupt the training loop.
 Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
 """
 function train!(loss, ps, data, opt; cb = () -> ())
+  ps = Params(ps)
   cb = runall(cb)
-  opt = runall(opt)
   @progress for d in data
     try
-      l = loss(d...)
-      @interrupts back!(l)
-      _update_params!(opt, ps)
+      gs = gradient(ps) do
+        loss(d...)
+      end
+      update!(opt, ps, gs)
       if cb() == :stop
         depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
         break
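The `train!` rewrite above is essentially a thin wrapper over `gradient` and `update!`. A minimal sketch of using those lower-level calls directly, assuming Flux v0.8 with this patch (the model, loss, and data below are placeholders):

```julia
using Flux
using Flux.Tracker: gradient
using Flux.Optimise: update!

m = Dense(10, 2)
loss(x, y) = Flux.mse(m(x), y)

ps  = Flux.params(m)     # Params: the trainable parameters of the model
opt = Descent(0.1)

x, y = rand(10), rand(2)
gs = gradient(() -> loss(x, y), ps)  # Grads keyed by each parameter
update!(opt, ps, gs)                 # one optimiser step over every parameter
```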
@@ -62,6 +62,7 @@ macro grad(ex)
 end

 include("idset.jl")
+include("params.jl")
 include("back.jl")
 include("numeric.jl")
 include("lib/real.jl")
@@ -1,3 +1,15 @@
+# The AD generates fairly large backtraces that are unhelpful if you interrupt
+# while training; this just cleans that up.
+macro interrupts(ex)
+  :(try $(esc(ex))
+  catch e
+    e isa InterruptException || rethrow()
+    throw(e)
+  end)
+end
+
+# In-place gradients
+
 init_grad(x) = zero(x)
 zero_grad!(x) = zero(x)
 zero_grad!(x::AbstractArray) = (x .= 0)
@@ -66,64 +78,34 @@ function back!(x, Δ; once = true)
   return
 end

+function extract_grad!(x)
+  x̄ = copy(grad(x))
+  x̄ = nobacksies("Use `gradient(...; nest = true)` for nested derivatives", x̄)
+  tracker(x).grad = zero_grad!(grad(x))
+  return x̄
+end
+
 function gradient_(f, xs...)
   xs = param.(data.(xs))
   l = f(xs...)
   losscheck(l)
-  back!(l)
-  nobacksies("Use `gradient(...; nest = true)` for nested derivatives",
-             grad.(xs))
+  @interrupts back!(l)
+  extract_grad!.(xs)
+end
+
+function gradient_(f, xs::Params)
+  l = f()
+  losscheck(l)
+  @interrupts back!(l)
+  gs = Grads()
+  for x in xs
+    gs[tracker(x)] = extract_grad!(x)
+  end
+  return gs
 end

 # Out-of-place gradients

-struct Params
-  order::Vector{Any}
-  params::IdSet{Any}
-  Params() = new([], IdSet())
-end
-
-@forward Params.order Base.iterate, Base.length
-
-function Base.push!(ps::Params, x)
-  if !(x in ps.params)
-    push!(ps.order, x)
-    push!(ps.params, x)
-  end
-  return ps
-end
-
-Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
-
-Params(xs) = push!(Params(), xs...)
-
-function Base.show(io::IO, ps::Params)
-  print(io, "Params([")
-  join(io, ps.order, ", ")
-  print(io, "])")
-end
-
-struct Grads
-  grads::IdDict{Any,Any}
-end
-
-Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
-
-Grads() = Grads(IdDict())
-
-@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate
-
-Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
-
-Base.getindex(g::Grads, x::Tracked) = g.grads[x]
-
-function Base.getindex(g::Grads, x)
-  istracked(x) || error("Object not tracked: $x")
-  g[tracker(x)]
-end
-
-accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
-
 function back_(g::Grads, c::Call, Δ)
   Δs = c.func(Δ)
   (Δs isa Tuple && length(Δs) >= length(c.args)) ||
@@ -182,8 +164,6 @@ end
 gradient(f, xs...; nest = false) =
   nest ? gradient_nested(f, xs...) : gradient_(f, xs...)

-gradient(f, ps::Params) = gradient_nested(f, ps)
-
 # Jacobians and Hessians

 import ..Flux
@@ -71,6 +71,11 @@ function update!(x::TrackedArray, Δ)
   return x
 end

+function update!(x::AbstractArray, Δ)
+  x .+= data(Δ)
+  return x
+end
+
 # Fallthrough methods

 for f in :[Base.size, Base.ndims, Base.collect].args
src/tracker/params.jl (new file, 46 lines)
@@ -0,0 +1,46 @@
+struct Params
+  order::Vector{Any}
+  params::IdSet{Any}
+  Params() = new([], IdSet())
+end
+
+@forward Params.order Base.iterate, Base.length
+
+function Base.push!(ps::Params, x)
+  if !(x in ps.params)
+    push!(ps.order, x)
+    push!(ps.params, x)
+  end
+  return ps
+end
+
+Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
+
+Params(xs) = push!(Params(), xs...)
+
+function Base.show(io::IO, ps::Params)
+  print(io, "Params([")
+  join(io, ps.order, ", ")
+  print(io, "])")
+end
+
+struct Grads
+  grads::IdDict{Any,Any}
+end
+
+Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
+
+Grads() = Grads(IdDict())
+
+@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate
+
+Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
+
+Base.getindex(g::Grads, x::Tracked) = g.grads[x]
+
+function Base.getindex(g::Grads, x)
+  istracked(x) || error("Object not tracked: $x")
+  g[tracker(x)]
+end
+
+accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
@@ -104,3 +104,99 @@ end
     @test (@allocated m(x)) < 100_000_000
   end
 end
+
+
+@testset "InstanceNorm" begin
+  # helper functions
+  expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
+  # begin tests
+  let m = InstanceNorm(2), sizes = (3, 2, 2),
+      x = param(reshape(collect(1:prod(sizes)), sizes))
+
+    @test m.β.data == [0, 0]  # initβ(2)
+    @test m.γ.data == [1, 1]  # initγ(2)
+
+    @test m.active
+
+    m(x)
+
+    #julia> x
+    #[:, :, 1] =
+    # 1.0 4.0
+    # 2.0 5.0
+    # 3.0 6.0
+    #
+    #[:, :, 2] =
+    # 7.0 10.0
+    # 8.0 11.0
+    # 9.0 12.0
+    #
+    # μ will be
+    # (1. + 2. + 3.) / 3 = 2.
+    # (4. + 5. + 6.) / 3 = 5.
+    #
+    # (7. + 8. + 9.) / 3 = 8.
+    # (10. + 11. + 12.) / 3 = 11.
+    #
+    # ∴ update rule with momentum:
+    # (1. - .1) * 0 + .1 * (2. + 8.) / 2 = .5
+    # (1. - .1) * 0 + .1 * (5. + 11.) / 2 = .8
+    @test m.μ ≈ [0.5, 0.8]
+    # momentum * var * num_items / (num_items - 1) + (1 - momentum) * sigma_sq
+    # julia> reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
+    # 2-element Array{Float64,1}:
+    #  1.
+    #  1.
+    @test m.σ² ≈ reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
+
+    testmode!(m)
+    @test !m.active
+
+    x′ = m(x).data
+    @test isapprox(x′[1], (1 - 0.5) / sqrt(1. + 1f-5), atol = 1.0e-5)
+  end
+  # with activation function
+  let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2),
+      x = param(reshape(collect(1:prod(sizes)), sizes))
+
+    affine_shape = collect(sizes)
+    affine_shape[1] = 1
+
+    @test m.active
+    m(x)
+
+    testmode!(m)
+    @test !m.active
+
+    y = m(x).data
+    @test isapprox(y, data(sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ))), atol = 1.0e-7)
+  end
+
+  let m = InstanceNorm(2), sizes = (2, 4, 1, 2, 3),
+      x = param(reshape(collect(1:prod(sizes)), sizes))
+    y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
+    y = reshape(m(y), sizes...)
+    @test m(x) == y
+  end
+
+  # check that μ, σ², and the output are the correct size for higher rank tensors
+  let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6),
+      x = param(reshape(collect(1:prod(sizes)), sizes))
+    y = m(x)
+    @test size(m.μ) == (sizes[end - 1], )
+    @test size(m.σ²) == (sizes[end - 1], )
+    @test size(y) == sizes
+  end
+
+  # show that instance norm is equal to batch norm when channel and batch dims are squashed
+  let m_inorm = InstanceNorm(2), m_bnorm = BatchNorm(12), sizes = (5, 5, 3, 4, 2, 6),
+      x = param(reshape(collect(1:prod(sizes)), sizes))
+    @test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
+  end
+
+  let m = InstanceNorm(32), x = randn(Float32, 416, 416, 32, 1);
+    m(x)
+    @test (@allocated m(x)) < 100_000_000
+  end
+
+end
@@ -4,21 +4,15 @@ using Flux.Tracker
 using Test
 @testset "Optimise" begin
   w = randn(10, 10)
-  @testset for Opt in [ADAMW, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
+  @testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
+                       NADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
+                       Momentum()]
     w′ = param(randn(10, 10))
    loss(x) = Flux.mse(w*x, w′*x)
-    opt = Opt(0.001)
-    if opt isa Descent || opt isa ADAGrad
-      opt = Opt(0.1)
-    end
-    if opt isa ADADelta
-      opt = Opt(0.9)
-    end
     for t = 1: 10^5
-      l = loss(rand(10))
-      back!(l)
-      delta = Optimise.apply!(opt, w′.data, w′.grad)
-      w′.data .-= delta
+      θ = Params([w′])
+      θ̄ = gradient(() -> loss(rand(10)), θ)
+      Optimise.update!(opt, θ, θ̄)
     end
     @test Flux.mse(w, w′) < 0.01
   end