functor refactor
This commit is contained in:
parent
2c71fc282b
commit
6529dbcbe6
@ -1,52 +1,67 @@
|
|||||||
import Adapt: adapt, adapt_storage
|
import Adapt: adapt, adapt_storage
|
||||||
import Zygote: IdSet
|
using Zygote: IdSet
|
||||||
|
|
||||||
children(x) = ()
|
functor(x) = (), _ -> x
|
||||||
mapchildren(f, x) = x
|
|
||||||
|
|
||||||
children(x::Tuple) = x
|
functor(x::Tuple) = x, y -> y
|
||||||
children(x::NamedTuple) = x
|
functor(x::NamedTuple) = x, y -> y
|
||||||
mapchildren(f, x::Tuple) = map(f, x)
|
|
||||||
mapchildren(f, x::NamedTuple) = map(f, x)
|
|
||||||
|
|
||||||
function treelike(m::Module, T, fs = fieldnames(T))
|
functor(x::AbstractArray) = x, y -> y
|
||||||
|
functor(x::AbstractArray{<:Number}) = (), _ -> x
|
||||||
|
|
||||||
|
function makefunctor(m::Module, T, fs = fieldnames(T))
|
||||||
@eval m begin
|
@eval m begin
|
||||||
Flux.children(x::$T) = ($([:(x.$f) for f in fs]...),)
|
Flux.functor(x::$T) = ($([:($f=x.$f) for f in fs]...),), y -> $T(y...)
|
||||||
Flux.mapchildren(f, x::$T) = $T(f.($children(x))...)
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
macro treelike(T, fs = nothing)
|
function functorm(T, fs = nothing)
|
||||||
fs == nothing || isexpr(fs, :tuple) || error("@treelike T (a, b)")
|
fs == nothing || isexpr(fs, :tuple) || error("@treelike T (a, b)")
|
||||||
fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
|
fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
|
||||||
:(treelike(@__MODULE__, $(esc(T)), $(fs...)))
|
:(makefunctor(@__MODULE__, $(esc(T)), $(fs...)))
|
||||||
end
|
end
|
||||||
|
|
||||||
isleaf(x) = isempty(children(x))
|
macro functor(args...)
|
||||||
|
functorm(args...)
|
||||||
|
end
|
||||||
|
|
||||||
function mapleaves(f, x; cache = IdDict())
|
isleaf(x) = functor(x)[1] === ()
|
||||||
|
|
||||||
|
function fmap1(f, x)
|
||||||
|
func, re = functor(x)
|
||||||
|
re(map(f, func))
|
||||||
|
end
|
||||||
|
|
||||||
|
function fmap(f, x; cache = IdDict())
|
||||||
haskey(cache, x) && return cache[x]
|
haskey(cache, x) && return cache[x]
|
||||||
cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x)
|
cache[x] = isleaf(x) ? f(x) : fmap1(x -> fmap(f, x, cache = cache), x)
|
||||||
end
|
end
|
||||||
|
|
||||||
function prefor(f, x; seen = IdSet())
|
children(m) = functor(m)[1]
|
||||||
x ∈ seen && return
|
|
||||||
|
params!(p::Params, x::AbstractArray{<:Real}, seen = IdSet()) = push!(p, x)
|
||||||
|
|
||||||
|
function params!(p::Params, x, seen = IdSet())
|
||||||
|
x in seen && return
|
||||||
push!(seen, x)
|
push!(seen, x)
|
||||||
f(x)
|
for child in children(x)
|
||||||
foreach(x -> prefor(f, x, seen = seen), children(x))
|
params!(p, child, seen)
|
||||||
return
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function params(m)
|
function params(m...)
|
||||||
ps = Params()
|
ps = Params()
|
||||||
prefor(p ->
|
params!(ps, m)
|
||||||
p isa AbstractArray{<:Real} &&
|
|
||||||
!any(p′ -> p′ === p, ps) && push!(ps, p),
|
|
||||||
m)
|
|
||||||
return ps
|
return ps
|
||||||
end
|
end
|
||||||
|
|
||||||
params(m...) = params(m)
|
# Deprecated stuff
|
||||||
|
macro treelike(args...)
|
||||||
|
functorm(args...)
|
||||||
|
end
|
||||||
|
mapleaves(f, x) = fmap(f, x)
|
||||||
|
|
||||||
|
# function params
|
||||||
|
|
||||||
function loadparams!(m, xs)
|
function loadparams!(m, xs)
|
||||||
for (p, x) in zip(params(m), xs)
|
for (p, x) in zip(params(m), xs)
|
||||||
@ -76,11 +91,3 @@ paramtype(T::Type{<:Real}, m) = mapleaves(x -> adapt(T, x), m)
|
|||||||
|
|
||||||
f32(m) = paramtype(Float32, m)
|
f32(m) = paramtype(Float32, m)
|
||||||
f64(m) = paramtype(Float64, m)
|
f64(m) = paramtype(Float64, m)
|
||||||
|
|
||||||
# General parameter map
|
|
||||||
|
|
||||||
function mapparams(f, m)
|
|
||||||
mapleaves(m) do x
|
|
||||||
x isa Union{AbstractArray,Number} ? f(x) : x
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
@ -24,8 +24,7 @@ end
|
|||||||
@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
|
@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
|
||||||
Base.iterate, Base.lastindex
|
Base.iterate, Base.lastindex
|
||||||
|
|
||||||
children(c::Chain) = c.layers
|
functor(c::Chain) = c.layers, ls -> Chain(ls...)
|
||||||
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
|
|
||||||
|
|
||||||
applychain(::Tuple{}, x) = x
|
applychain(::Tuple{}, x) = x
|
||||||
applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
|
applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
|
||||||
|
@ -166,11 +166,7 @@ function (BN::BatchNorm)(x)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
children(BN::BatchNorm) =
|
@functor BatchNorm
|
||||||
(BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum)
|
|
||||||
|
|
||||||
mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN)
|
|
||||||
BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum)
|
|
||||||
|
|
||||||
function Base.show(io::IO, l::BatchNorm)
|
function Base.show(io::IO, l::BatchNorm)
|
||||||
print(io, "BatchNorm($(join(size(l.β), ", "))")
|
print(io, "BatchNorm($(join(size(l.β), ", "))")
|
||||||
@ -261,11 +257,7 @@ function (in::InstanceNorm)(x)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
children(in::InstanceNorm) =
|
@functor InstanceNorm
|
||||||
(in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum)
|
|
||||||
|
|
||||||
mapchildren(f, in::InstanceNorm) = # e.g. mapchildren(cu, in)
|
|
||||||
InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum)
|
|
||||||
|
|
||||||
function Base.show(io::IO, l::InstanceNorm)
|
function Base.show(io::IO, l::InstanceNorm)
|
||||||
print(io, "InstanceNorm($(join(size(l.β), ", "))")
|
print(io, "InstanceNorm($(join(size(l.β), ", "))")
|
||||||
@ -360,11 +352,7 @@ function(gn::GroupNorm)(x)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
children(gn::GroupNorm) =
|
@functor GroupNorm
|
||||||
(gn.λ, gn.β, gn.γ, gn.μ, gn.σ², gn.ϵ, gn.momentum)
|
|
||||||
|
|
||||||
mapchildren(f, gn::GroupNorm) = # e.g. mapchildren(cu, BN)
|
|
||||||
GroupNorm(gn.G,gn.λ, f(gn.β), f(gn.γ), f(gn.μ), f(gn.σ²), gn.ϵ, gn.momentum)
|
|
||||||
|
|
||||||
function Base.show(io::IO, l::GroupNorm)
|
function Base.show(io::IO, l::GroupNorm)
|
||||||
print(io, "GroupNorm($(join(size(l.β), ", "))")
|
print(io, "GroupNorm($(join(size(l.β), ", "))")
|
||||||
|
@ -52,7 +52,8 @@ Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
|
|||||||
|
|
||||||
rnn.state = hidden(rnn.cell)
|
rnn.state = hidden(rnn.cell)
|
||||||
"""
|
"""
|
||||||
reset!(m) = prefor(x -> x isa Recur && (x.state = x.init), m)
|
reset!(m::Recur) = (m.state = m.init)
|
||||||
|
reset!(m) = foreach(reset!, functor(m)[1])
|
||||||
|
|
||||||
flip(f, xs) = reverse(f.(reverse(xs)))
|
flip(f, xs) = reverse(f.(reverse(xs)))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user