From 6eda27919030d9150a960010e3b57331e5359cac Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 14 Apr 2020 13:58:52 +0100 Subject: [PATCH] split out functor --- Manifest.toml | 8 ++++++++ Project.toml | 1 + src/functor.jl | 37 +------------------------------------ 3 files changed, 10 insertions(+), 36 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index bc74cb6e..672b3752 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -156,6 +156,14 @@ git-tree-sha1 = "869540e4367122fbffaace383a5bdc34d6e5e5ac" uuid = "f6369f11-7733-5829-9624-2563aa707210" version = "0.10.10" +[[Functors]] +deps = ["MacroTools"] +git-tree-sha1 = "58dd223f9ad2601b0e9964fd65fa9d2c7219be41" +repo-rev = "master" +repo-url = "https://github.com/FluxML/Functors.jl" +uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +version = "0.1.0" + [[GPUArrays]] deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"] git-tree-sha1 = "d586762b08dcda13228df8967119b9cb6f22ade5" diff --git a/Project.toml b/Project.toml index 1883d974..9ed0beae 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193" Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae" DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" diff --git a/src/functor.jl b/src/functor.jl index 0d7c55f1..c97fd737 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -1,41 +1,6 @@ import Adapt: adapt, adapt_storage using Zygote: IdSet - -functor(x) = (), _ -> x - -functor(x::Tuple) = x, y -> y -functor(x::NamedTuple) = x, y -> y - -functor(x::AbstractArray) = x, y -> y -functor(x::AbstractArray{<:Number}) = (), _ -> x - -function makefunctor(m::Module, T, fs = fieldnames(T)) - @eval m begin - Flux.functor(x::$T) = ($([:($f=x.$f) for f in fs]...),), y -> $T(y...) - end -end - -function functorm(T, fs = nothing) - fs == nothing || isexpr(fs, :tuple) || error("@functor T (a, b)") - fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)] - :(makefunctor(@__MODULE__, $(esc(T)), $(fs...))) -end - -macro functor(args...) - functorm(args...) -end - -isleaf(x) = functor(x)[1] === () - -function fmap1(f, x) - func, re = functor(x) - re(map(f, func)) -end - -function fmap(f, x; cache = IdDict()) - haskey(cache, x) && return cache[x] - cache[x] = isleaf(x) ? f(x) : fmap1(x -> fmap(f, x, cache = cache), x) -end +import Functors: @functor, functor, fmap trainable(m) = functor(m)[1]