Merge #1174
1174: Functors r=MikeInnes a=MikeInnes Just splits out the implementation to the [Functors](https://github.com/FluxML/Functors.jl) package, so the same traits can be used elsewhere (e.g. Optimisers.jl) without depending on all of Flux. Co-authored-by: Mike J Innes <mike.j.innes@gmail.com>
This commit is contained in:
commit
a84e08cf28
|
@ -156,6 +156,12 @@ git-tree-sha1 = "869540e4367122fbffaace383a5bdc34d6e5e5ac"
|
|||
uuid = "f6369f11-7733-5829-9624-2563aa707210"
|
||||
version = "0.10.10"
|
||||
|
||||
[[Functors]]
|
||||
deps = ["MacroTools"]
|
||||
git-tree-sha1 = "f40adc6422f548176bb4351ebd29e4abf773040a"
|
||||
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
|
||||
version = "0.1.0"
|
||||
|
||||
[[GPUArrays]]
|
||||
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"]
|
||||
git-tree-sha1 = "d586762b08dcda13228df8967119b9cb6f22ade5"
|
||||
|
|
|
@ -9,6 +9,7 @@ CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
|
|||
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
|
||||
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
|
||||
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
|
||||
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
|
||||
Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
|
||||
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
|
||||
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
|
||||
|
|
|
@ -1,41 +1,6 @@
|
|||
import Adapt: adapt, adapt_storage
|
||||
using Zygote: IdSet
|
||||
|
||||
functor(x) = (), _ -> x
|
||||
|
||||
functor(x::Tuple) = x, y -> y
|
||||
functor(x::NamedTuple) = x, y -> y
|
||||
|
||||
functor(x::AbstractArray) = x, y -> y
|
||||
functor(x::AbstractArray{<:Number}) = (), _ -> x
|
||||
|
||||
function makefunctor(m::Module, T, fs = fieldnames(T))
|
||||
@eval m begin
|
||||
Flux.functor(x::$T) = ($([:($f=x.$f) for f in fs]...),), y -> $T(y...)
|
||||
end
|
||||
end
|
||||
|
||||
function functorm(T, fs = nothing)
|
||||
fs == nothing || isexpr(fs, :tuple) || error("@functor T (a, b)")
|
||||
fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
|
||||
:(makefunctor(@__MODULE__, $(esc(T)), $(fs...)))
|
||||
end
|
||||
|
||||
macro functor(args...)
|
||||
functorm(args...)
|
||||
end
|
||||
|
||||
isleaf(x) = functor(x)[1] === ()
|
||||
|
||||
function fmap1(f, x)
|
||||
func, re = functor(x)
|
||||
re(map(f, func))
|
||||
end
|
||||
|
||||
function fmap(f, x; cache = IdDict())
|
||||
haskey(cache, x) && return cache[x]
|
||||
cache[x] = isleaf(x) ? f(x) : fmap1(x -> fmap(f, x, cache = cache), x)
|
||||
end
|
||||
import Functors: @functor, functor, fmap
|
||||
|
||||
trainable(m) = functor(m)[1]
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ end
|
|||
@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
|
||||
Base.iterate, Base.lastindex
|
||||
|
||||
functor(c::Chain) = c.layers, ls -> Chain(ls...)
|
||||
functor(::Type{<:Chain}, c) = c.layers, ls -> Chain(ls...)
|
||||
|
||||
applychain(::Tuple{}, x) = x
|
||||
applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
|
||||
|
|
Loading…
Reference in New Issue