From 6529dbcbe69d3a94b6edd131051ec0df7e26820d Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Thu, 19 Sep 2019 15:22:11 +0100 Subject: [PATCH] functor refactor --- src/functor.jl | 75 ++++++++++++++++++++++------------------- src/layers/basic.jl | 3 +- src/layers/normalise.jl | 18 ++-------- src/layers/recurrent.jl | 3 +- 4 files changed, 47 insertions(+), 52 deletions(-) diff --git a/src/functor.jl b/src/functor.jl index 42b10f23..2113d7e4 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -1,52 +1,67 @@ import Adapt: adapt, adapt_storage -import Zygote: IdSet +using Zygote: IdSet -children(x) = () -mapchildren(f, x) = x +functor(x) = (), _ -> x -children(x::Tuple) = x -children(x::NamedTuple) = x -mapchildren(f, x::Tuple) = map(f, x) -mapchildren(f, x::NamedTuple) = map(f, x) +functor(x::Tuple) = x, y -> y +functor(x::NamedTuple) = x, y -> y -function treelike(m::Module, T, fs = fieldnames(T)) +functor(x::AbstractArray) = x, y -> y +functor(x::AbstractArray{<:Number}) = (), _ -> x + +function makefunctor(m::Module, T, fs = fieldnames(T)) @eval m begin - Flux.children(x::$T) = ($([:(x.$f) for f in fs]...),) - Flux.mapchildren(f, x::$T) = $T(f.($children(x))...) + Flux.functor(x::$T) = ($([:($f=x.$f) for f in fs]...),), y -> $T(y...) end end -macro treelike(T, fs = nothing) +function functorm(T, fs = nothing) fs == nothing || isexpr(fs, :tuple) || error("@treelike T (a, b)") fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)] - :(treelike(@__MODULE__, $(esc(T)), $(fs...))) + :(makefunctor(@__MODULE__, $(esc(T)), $(fs...))) end -isleaf(x) = isempty(children(x)) +macro functor(args...) + functorm(args...) +end -function mapleaves(f, x; cache = IdDict()) +isleaf(x) = functor(x)[1] === () + +function fmap1(f, x) + func, re = functor(x) + re(map(f, func)) +end + +function fmap(f, x; cache = IdDict()) haskey(cache, x) && return cache[x] - cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x) + cache[x] = isleaf(x) ? f(x) : fmap1(x -> fmap(f, x, cache = cache), x) end -function prefor(f, x; seen = IdSet()) - x ∈ seen && return +children(m) = functor(m)[1] + +params!(p::Params, x::AbstractArray{<:Real}, seen = IdSet()) = push!(p, x) + +function params!(p::Params, x, seen = IdSet()) + x in seen && return push!(seen, x) - f(x) - foreach(x -> prefor(f, x, seen = seen), children(x)) - return + for child in children(x) + params!(p, child, seen) + end end -function params(m) +function params(m...) ps = Params() - prefor(p -> - p isa AbstractArray{<:Real} && - !any(p′ -> p′ === p, ps) && push!(ps, p), - m) + params!(ps, m) return ps end -params(m...) = params(m) +# Deprecated stuff +macro treelike(args...) + functorm(args...) +end +mapleaves(f, x) = fmap(f, x) + +# function params function loadparams!(m, xs) for (p, x) in zip(params(m), xs) @@ -76,11 +91,3 @@ paramtype(T::Type{<:Real}, m) = mapleaves(x -> adapt(T, x), m) f32(m) = paramtype(Float32, m) f64(m) = paramtype(Float64, m) - -# General parameter map - -function mapparams(f, m) - mapleaves(m) do x - x isa Union{AbstractArray,Number} ? f(x) : x - end -end diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 0cebead1..1d885916 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -24,8 +24,7 @@ end @forward Chain.layers Base.getindex, Base.length, Base.first, Base.last, Base.iterate, Base.lastindex -children(c::Chain) = c.layers -mapchildren(f, c::Chain) = Chain(f.(c.layers)...) +functor(c::Chain) = c.layers, ls -> Chain(ls...) applychain(::Tuple{}, x) = x applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x)) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 61a62adf..7ea601f8 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -166,11 +166,7 @@ function (BN::BatchNorm)(x) end end -children(BN::BatchNorm) = - (BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum) - -mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN) - BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum) +@functor BatchNorm function Base.show(io::IO, l::BatchNorm) print(io, "BatchNorm($(join(size(l.β), ", "))") @@ -261,11 +257,7 @@ function (in::InstanceNorm)(x) end end -children(in::InstanceNorm) = - (in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum) - -mapchildren(f, in::InstanceNorm) = # e.g. mapchildren(cu, in) - InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum) +@functor InstanceNorm function Base.show(io::IO, l::InstanceNorm) print(io, "InstanceNorm($(join(size(l.β), ", "))") @@ -360,11 +352,7 @@ function(gn::GroupNorm)(x) end end -children(gn::GroupNorm) = - (gn.λ, gn.β, gn.γ, gn.μ, gn.σ², gn.ϵ, gn.momentum) - -mapchildren(f, gn::GroupNorm) = # e.g. mapchildren(cu, BN) - GroupNorm(gn.G,gn.λ, f(gn.β), f(gn.γ), f(gn.μ), f(gn.σ²), gn.ϵ, gn.momentum) +@functor GroupNorm function Base.show(io::IO, l::GroupNorm) print(io, "GroupNorm($(join(size(l.β), ", "))") diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index b5eea4a4..ad8c6e80 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -52,7 +52,8 @@ Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to rnn.state = hidden(rnn.cell) """ -reset!(m) = prefor(x -> x isa Recur && (x.state = x.init), m) +reset!(m::Recur) = (m.state = m.init) +reset!(m) = foreach(reset!, functor(m)[1]) flip(f, xs) = reverse(f.(reverse(xs)))