865: Functor r=MikeInnes a=MikeInnes

This refactors our current `@treelike` infrastructure. It somewhat formalises what we're doing around the idea of a Flux model as a functor, i.e. something that can be mapped over.

This is much more flexible than what we had before, and avoids some issues. It allows layers to have state that isn't mappable; it allows for dispatch when walking the tree, which means layers like `BatchNorm` can have non-trainable parameters; and it also allows for zipped mapping like `fmap(+, xs, ys)`, which isn't implemented yet but will be useful for the new optimisers work.

The main downside is that the term `functor` has been previously used in the Julia community as a malapropism for "thing that behaves like a function"; but hopefully this can start to reduce that usage.

Co-authored-by: Mike Innes <mike.j.innes@gmail.com>
This commit is contained in:
bors[bot] 2019-09-24 16:36:10 +00:00 committed by GitHub
commit acb6a89245
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 131 additions and 127 deletions

View File

@ -25,16 +25,16 @@ loss(x, y) # ~ 3
Note that we convert both the parameters (`W`, `b`) and the data set (`x`, `y`) to cuda arrays. Taking derivatives and training works exactly as before.
If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `mapleaves`, which allows you to alter all parameters of a model at once.
If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `fmap`, which allows you to alter all parameters of a model at once.
```julia
d = Dense(10, 5, σ)
d = mapleaves(cu, d)
d = fmap(cu, d)
d.W # Tracked CuArray
d(cu(rand(10))) # CuArray output
m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
m = mapleaves(cu, m)
m = fmap(cu, m)
d(cu(rand(10)))
```

View File

@ -215,7 +215,7 @@ m(5) # => 26
Flux provides a set of helpers for custom layers, which you can enable by calling
```julia
Flux.@treelike Affine
Flux.@functor Affine
```
This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).

View File

@ -11,7 +11,7 @@ export gradient
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
SkipConnection, params, mapleaves, cpu, gpu, f32, f64
SkipConnection, params, fmap, cpu, gpu, f32, f64
include("optimise/Optimise.jl")
using .Optimise
@ -35,7 +35,7 @@ end
include("utils.jl")
include("onehot.jl")
include("treelike.jl")
include("functor.jl")
include("layers/stateless.jl")
include("layers/basic.jl")

91
src/functor.jl Normal file
View File

@ -0,0 +1,91 @@
import Adapt: adapt, adapt_storage
using Zygote: IdSet
functor(x) = (), _ -> x
functor(x::Tuple) = x, y -> y
functor(x::NamedTuple) = x, y -> y
functor(x::AbstractArray) = x, y -> y
functor(x::AbstractArray{<:Number}) = (), _ -> x
function makefunctor(m::Module, T, fs = fieldnames(T))
@eval m begin
Flux.functor(x::$T) = ($([:($f=x.$f) for f in fs]...),), y -> $T(y...)
end
end
function functorm(T, fs = nothing)
fs == nothing || isexpr(fs, :tuple) || error("@functor T (a, b)")
fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
:(makefunctor(@__MODULE__, $(esc(T)), $(fs...)))
end
macro functor(args...)
functorm(args...)
end
isleaf(x) = functor(x)[1] === ()
function fmap1(f, x)
func, re = functor(x)
re(map(f, func))
end
function fmap(f, x; cache = IdDict())
haskey(cache, x) && return cache[x]
cache[x] = isleaf(x) ? f(x) : fmap1(x -> fmap(f, x, cache = cache), x)
end
trainable(m) = functor(m)[1]
params!(p::Params, x::AbstractArray{<:Real}, seen = IdSet()) = push!(p, x)
function params!(p::Params, x, seen = IdSet())
x in seen && return
push!(seen, x)
for child in trainable(x)
params!(p, child, seen)
end
end
function params(m...)
ps = Params()
params!(ps, m)
return ps
end
# Deprecated stuff
macro treelike(args...)
functorm(args...)
end
mapleaves(f, x) = fmap(f, x)
function loadparams!(m, xs)
for (p, x) in zip(params(m), xs)
size(p) == size(x) ||
error("Expected param size $(size(p)), got $(size(x))")
copyto!(p, x)
end
end
# CPU/GPU movement conveniences
cpu(m) = fmap(x -> adapt(Array, x), m)
const gpu_adaptor = if has_cuarrays()
CuArrays.cu
else
identity
end
gpu(x) = fmap(gpu_adaptor, x)
# Precision
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs)
paramtype(T::Type{<:Real}, m) = fmap(x -> adapt(T, x), m)
f32(m) = paramtype(Float32, m)
f64(m) = paramtype(Float64, m)

View File

@ -24,8 +24,7 @@ end
@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
Base.iterate, Base.lastindex
children(c::Chain) = c.layers
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
functor(c::Chain) = c.layers, ls -> Chain(ls...)
applychain(::Tuple{}, x) = x
applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
@ -92,7 +91,7 @@ function Dense(in::Integer, out::Integer, σ = identity;
return Dense(initW(out, in), initb(out), σ)
end
@treelike Dense
@functor Dense
function (a::Dense)(x::AbstractArray)
W, b, σ = a.W, a.b, a.σ
@ -131,7 +130,7 @@ end
Diagonal(in::Integer; initα = ones, initβ = zeros) =
Diagonal(initα(in), initβ(in))
@treelike Diagonal
@functor Diagonal
function (a::Diagonal)(x)
α, β = a.α, a.β
@ -184,7 +183,7 @@ function Maxout(f, n_alts)
return Maxout(over)
end
@treelike Maxout
@functor Maxout
function (mo::Maxout)(input::AbstractArray)
mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
@ -209,7 +208,7 @@ struct SkipConnection
connection #user can pass arbitrary connections here, such as (a,b) -> a + b
end
@treelike SkipConnection
@functor SkipConnection
function (skip::SkipConnection)(input)
#We apply the layers to the input and return the result of the application of the layers and the original input

View File

@ -45,7 +45,7 @@ Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
Conv(init(k..., ch...), zeros(ch[2]), σ,
stride = stride, pad = pad, dilation = dilation)
@treelike Conv
@functor Conv
function (c::Conv)(x::AbstractArray)
# TODO: breaks gpu broadcast :(
@ -102,7 +102,7 @@ ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity
ConvTranspose(init(k..., reverse(ch)...), zeros(ch[2]), σ,
stride = stride, pad = pad, dilation = dilation)
@treelike ConvTranspose
@functor ConvTranspose
function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
# Calculate size of "input", from ∇conv_data()'s perspective...
@ -180,7 +180,7 @@ function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ =
)
end
@treelike DepthwiseConv
@functor DepthwiseConv
function (c::DepthwiseConv)(x)
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
@ -244,7 +244,7 @@ CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
CrossCor(init(k..., ch...), zeros(ch[2]), σ,
stride = stride, pad = pad, dilation = dilation)
@treelike CrossCor
@functor CrossCor
function crosscor(x, w, ddims::DenseConvDims)
ddims = DenseConvDims(ddims, F=true)

View File

@ -82,7 +82,7 @@ end
LayerNorm(h::Integer) =
LayerNorm(Diagonal(h))
@treelike LayerNorm
@functor LayerNorm
(a::LayerNorm)(x) = a.diag(normalise(x))
@ -134,6 +134,8 @@ BatchNorm(chs::Integer, λ = identity;
BatchNorm(λ, initβ(chs), initγ(chs),
zeros(chs), ones(chs), ϵ, momentum)
trainable(bn::BatchNorm) = (bn.β, bn.γ)
function (BN::BatchNorm)(x)
size(x, ndims(x)-1) == length(BN.β) ||
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
@ -166,11 +168,7 @@ function (BN::BatchNorm)(x)
end
end
children(BN::BatchNorm) =
(BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum)
mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN)
BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum)
@functor BatchNorm
function Base.show(io::IO, l::BatchNorm)
print(io, "BatchNorm($(join(size(l.β), ", "))")
@ -224,6 +222,8 @@ InstanceNorm(chs::Integer, λ = identity;
InstanceNorm(λ, initβ(chs), initγ(chs),
zeros(chs), ones(chs), ϵ, momentum)
trainable(in::InstanceNorm) = (in.β, in.γ)
function (in::InstanceNorm)(x)
size(x, ndims(x)-1) == length(in.β) ||
error("InstanceNorm expected $(length(in.β)) channels, got $(size(x, ndims(x)-1))")
@ -261,11 +261,7 @@ function (in::InstanceNorm)(x)
end
end
children(in::InstanceNorm) =
(in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum)
mapchildren(f, in::InstanceNorm) = # e.g. mapchildren(cu, in)
InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum)
@functor InstanceNorm
function Base.show(io::IO, l::InstanceNorm)
print(io, "InstanceNorm($(join(size(l.β), ", "))")
@ -311,6 +307,8 @@ GroupNorm(chs::Integer, G::Integer, λ = identity;
GroupNorm(G, λ, initβ(chs), initγ(chs),
zeros(G,1), ones(G,1), ϵ, momentum)
trainable(gn::GroupNorm) = (gn.β, gn.γ)
function(gn::GroupNorm)(x)
size(x,ndims(x)-1) == length(gn.β) || error("Group Norm expected $(length(gn.β)) channels, but got $(size(x,ndims(x)-1)) channels")
ndims(x) > 2 || error("Need to pass at least 3 channels for Group Norm to work")
@ -360,11 +358,7 @@ function(gn::GroupNorm)(x)
end
end
children(gn::GroupNorm) =
(gn.λ, gn.β, gn.γ, gn.μ, gn.σ², gn.ϵ, gn.momentum)
mapchildren(f, gn::GroupNorm) = # e.g. mapchildren(cu, BN)
GroupNorm(gn.G,gn.λ, f(gn.β), f(gn.γ), f(gn.μ), f(gn.σ²), gn.ϵ, gn.momentum)
@functor GroupNorm
function Base.show(io::IO, l::GroupNorm)
print(io, "GroupNorm($(join(size(l.β), ", "))")

View File

@ -38,7 +38,7 @@ function (m::Recur)(xs...)
return y
end
@treelike Recur cell, init
@functor Recur cell, init
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
@ -52,7 +52,8 @@ Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
rnn.state = hidden(rnn.cell)
"""
reset!(m) = prefor(x -> x isa Recur && (x.state = x.init), m)
reset!(m::Recur) = (m.state = m.init)
reset!(m) = foreach(reset!, functor(m)[1])
flip(f, xs) = reverse(f.(reverse(xs)))
@ -79,7 +80,7 @@ end
hidden(m::RNNCell) = m.h
@treelike RNNCell
@functor RNNCell
function Base.show(io::IO, l::RNNCell)
print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1))
@ -127,7 +128,7 @@ end
hidden(m::LSTMCell) = (m.h, m.c)
@treelike LSTMCell
@functor LSTMCell
Base.show(io::IO, l::LSTMCell) =
print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷4, ")")
@ -168,7 +169,7 @@ end
hidden(m::GRUCell) = m.h
@treelike GRUCell
@functor GRUCell
Base.show(io::IO, l::GRUCell) =
print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")

View File

@ -1,86 +0,0 @@
import Adapt: adapt, adapt_storage
import Zygote: IdSet
children(x) = ()
mapchildren(f, x) = x
children(x::Tuple) = x
children(x::NamedTuple) = x
mapchildren(f, x::Tuple) = map(f, x)
mapchildren(f, x::NamedTuple) = map(f, x)
function treelike(m::Module, T, fs = fieldnames(T))
@eval m begin
Flux.children(x::$T) = ($([:(x.$f) for f in fs]...),)
Flux.mapchildren(f, x::$T) = $T(f.($children(x))...)
end
end
macro treelike(T, fs = nothing)
fs == nothing || isexpr(fs, :tuple) || error("@treelike T (a, b)")
fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
:(treelike(@__MODULE__, $(esc(T)), $(fs...)))
end
isleaf(x) = isempty(children(x))
function mapleaves(f, x; cache = IdDict())
haskey(cache, x) && return cache[x]
cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x)
end
function prefor(f, x; seen = IdSet())
x seen && return
push!(seen, x)
f(x)
foreach(x -> prefor(f, x, seen = seen), children(x))
return
end
function params(m)
ps = Params()
prefor(p ->
p isa AbstractArray{<:Real} &&
!any(p -> p === p, ps) && push!(ps, p),
m)
return ps
end
params(m...) = params(m)
function loadparams!(m, xs)
for (p, x) in zip(params(m), xs)
size(p) == size(x) ||
error("Expected param size $(size(p)), got $(size(x))")
copyto!(p, x)
end
end
# CPU/GPU movement conveniences
cpu(m) = mapleaves(x -> adapt(Array, x), m)
const gpu_adaptor = if has_cuarrays()
CuArrays.cu
else
identity
end
gpu(x) = mapleaves(gpu_adaptor, x)
# Precision
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs)
paramtype(T::Type{<:Real}, m) = mapleaves(x -> adapt(T, x), m)
f32(m) = paramtype(Float32, m)
f64(m) = paramtype(Float64, m)
# General parameter map
function mapparams(f, m)
mapleaves(m) do x
x isa Union{AbstractArray,Number} ? f(x) : x
end
end

View File

@ -13,7 +13,7 @@ end
@testset "RNN" begin
@testset for R in [RNN, GRU, LSTM], batch_size in (1, 5)
rnn = R(10, 5)
curnn = mapleaves(gpu, rnn)
curnn = fmap(gpu, rnn)
Flux.reset!(rnn)
Flux.reset!(curnn)

View File

@ -42,6 +42,8 @@ end
let m = BatchNorm(2), x = [1.0 3.0 5.0;
2.0 4.0 6.0]
@test length(params(m)) == 2
@test m.β == [0, 0] # initβ(2)
@test m.γ == [1, 1] # initγ(2)
# initial m.σ is 1
@ -109,7 +111,9 @@ end
expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
# begin tests
let m = InstanceNorm(2), sizes = (3, 2, 2),
x = reshape(collect(1:prod(sizes)), sizes)
x = reshape(collect(1:prod(sizes)), sizes)
@test length(params(m)) == 2
x = Float64.(x)
@test m.β == [0, 0] # initβ(2)
@test m.γ == [1, 1] # initγ(2)
@ -192,7 +196,9 @@ end
squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions
let m = GroupNorm(4,2), sizes = (3,4,2),
x = reshape(collect(1:prod(sizes)), sizes)
x = reshape(collect(1:prod(sizes)), sizes)
@test length(params(m)) == 2
x = Float64.(x)
@test m.β == [0, 0, 0, 0] # initβ(32)
@test m.γ == [1, 1, 1, 1] # initγ(32)

View File

@ -83,7 +83,6 @@ end
# Self-referential array. Just want params, no stack overflow pls.
r = Any[nothing,m]
Flux.children(a::Vector{Any}) = Tuple(a)
r[1] = r
@test size.(params(r)) == [(5, 10), (5, 5), (5,), (5,)]
end