Flux.jl/src/functor.jl

94 lines
1.9 KiB
Julia
Raw Normal View History

2019-01-25 10:06:37 +00:00
import Adapt: adapt, adapt_storage
2019-09-19 14:22:11 +00:00
using Zygote: IdSet
2018-01-16 17:58:14 +00:00
2019-09-19 14:22:11 +00:00
# Fallback method: a value with no more specific `functor` method is a leaf.
# It exposes no children (the empty tuple), and its reconstructor ignores
# whatever it is given and hands back the original value.
function functor(x)
  rebuild(_) = x
  return (), rebuild
end
2017-09-27 20:11:21 +00:00
2019-09-19 14:22:11 +00:00
# Tuples are their own children; the reconstructor returns the mapped
# children unchanged.
function functor(x::Tuple)
  return x, identity
end

# Named tuples behave the same way: the value itself is the collection of
# children and reconstruction is a pass-through.
function functor(x::NamedTuple)
  return x, identity
end
2017-10-31 16:37:41 +00:00
2019-09-19 14:22:11 +00:00
# A general array is a container: its elements are the children and the
# reconstructor returns the mapped collection as-is.
function functor(x::AbstractArray)
  return x, identity
end

# An array whose elements are numbers is a leaf: no children, and the
# reconstructor always returns the original array.
function functor(x::AbstractArray{<:Number})
  return (), _ -> x
end
# Define a `functor` method for the type `T` by evaluating code in module `m`.
#
# Arguments:
# - `m`:  module in which the method definition is evaluated.
# - `T`:  the type to make traversable.
# - `fs`: names of the fields exposed as children (defaults to every field).
#
# The generated method returns a named tuple of the selected fields together
# with a reconstructor. The reconstructor calls `T` with one argument per field
# of `T`, in field order: fields listed in `fs` are taken from the mapped
# children `y` (matched by their position in `fs`), and every other field is
# copied from the original object `x`. When `fs` is the full field list in
# declaration order this is the same call as `T(y...)`.
function makefunctor(m::Module, T, fs = fieldnames(T))
  # Interpolate the defining module object itself instead of spelling out a
  # module name, so the target module does not need that name in scope.
  owner = @__MODULE__
  childexprs = [:($f = x.$f) for f in fs]
  ctorargs = map(fieldnames(T)) do f
    i = findfirst(isequal(f), fs)
    i === nothing ? :(x.$f) : :(y[$i])
  end
  @eval m begin
    $owner.functor(x::$T) = ($(childexprs...),), y -> $T($(ctorargs...))
  end
end
2019-09-19 14:22:11 +00:00
# Build the expression that the `@functor` / `@treelike` macros expand to.
#
# Arguments:
# - `T`:  the expression naming the type; it is escaped so it resolves in the
#         caller's scope.
# - `fs`: either `nothing` (no field list given) or a tuple expression of
#         field names such as `(a, b)`. Anything else raises an error.
#
# Returns a call expression `makefunctor(@__MODULE__, T[, (:a, :b)])`; the
# field tuple is only passed along when one was supplied.
function functorm(T, fs = nothing)
  # `===` is the identity check for the `nothing` singleton; `==` would go
  # through generic equality dispatch.
  fs === nothing || isexpr(fs, :tuple) || error("@treelike T (a, b)")
  # Turn the bare names `(a, b)` into a tuple of quoted symbols `(:a, :b)`.
  fieldargs = fs === nothing ? Any[] : Any[:($(map(QuoteNode, fs.args)...),)]
  return :(makefunctor(@__MODULE__, $(esc(T)), $(fieldargs...)))
end
# `@functor T` or `@functor T (a, b)`: forwards its arguments to `functorm`,
# which builds the `makefunctor` call that this macro expands to.
macro functor(args...)
  return functorm(args...)
end
2019-09-19 14:22:11 +00:00
# A value is a leaf when `functor` reports the empty tuple as its children.
function isleaf(x)
  kids, _ = functor(x)
  return kids === ()
end
2017-09-27 20:11:21 +00:00
2019-09-19 14:22:11 +00:00
# Apply `f` to each direct child of `x` (one level only) and rebuild `x` from
# the mapped children using the reconstructor returned by `functor`.
function fmap1(f, x)
  kids, rebuild = functor(x)
  mapped = map(f, kids)
  return rebuild(mapped)
end
# Recursively map `f` over the leaves of `x`, rebuilding the structure around
# them.
#
# `cache` is keyed by object identity (`IdDict`) and remembers the result for
# every object already visited, so an object reached more than once is
# processed once and the same result is reused.
function fmap(f, x; cache = IdDict())
  if haskey(cache, x)
    return cache[x]
  end
  result = if isleaf(x)
    f(x)
  else
    fmap1(child -> fmap(f, child, cache = cache), x)
  end
  cache[x] = result
  return result
end
2017-09-27 20:58:34 +00:00
2019-09-19 14:22:11 +00:00
# The children of `m`: the first element of the pair returned by `functor`.
function children(m)
  return first(functor(m))
end
# An array of real numbers is a parameter: add it to `p`. The `seen` set is
# accepted for signature compatibility with the generic method but not used.
function params!(p::Params, x::AbstractArray{<:Real}, seen = IdSet())
  return push!(p, x)
end
# Generic traversal: walk the children of `x` and collect parameters into `p`.
# `seen` records objects already visited so that each object is descended
# into only once. Returns `nothing`.
function params!(p::Params, x, seen = IdSet())
  if x in seen
    return nothing
  end
  push!(seen, x)
  foreach(child -> params!(p, child, seen), children(x))
  return nothing
end
2019-09-19 14:22:11 +00:00
# Collect the parameters reachable from the given objects into a fresh
# `Params` collection. The arguments are traversed together as one tuple.
function params(m...)
  collected = Params()
  params!(collected, m)
  return collected
end
2017-11-07 19:34:35 +00:00
2019-09-19 14:22:11 +00:00
# Deprecated stuff
# Older spelling kept in the deprecated section of this file; it expands to
# exactly the same expression as `@functor`.
macro treelike(args...)
  return functorm(args...)
end
# Older name kept in the deprecated section of this file; forwards to `fmap`.
function mapleaves(f, x)
  return fmap(f, x)
end
# function params
2018-02-26 23:10:59 +00:00
2018-03-06 02:45:31 +00:00
# Copy the arrays in `xs` into the parameters of `m`, pairing them in order.
#
# Each pair must have matching sizes, otherwise an error is raised. Pairing
# uses `zip`, so iteration stops at the end of the shorter of the two
# sequences.
function loadparams!(m, xs)
  for (p, x) in zip(params(m), xs)
    if size(p) != size(x)
      error("Expected param size $(size(p)), got $(size(x))")
    end
    copyto!(p, x)
  end
end
2018-02-26 23:10:59 +00:00
# CPU/GPU movement conveniences
2018-02-28 22:51:08 +00:00
# Map every leaf of `m` through `adapt(Array, x)`.
# Calls `fmap` directly rather than the deprecated `mapleaves` alias.
cpu(m) = fmap(x -> adapt(Array, x), m)
2018-02-26 23:10:59 +00:00
# Function applied to each leaf by `gpu`: `CuArrays.cu` when `has_cuarrays()`
# is true at load time, otherwise `identity` (a no-op).
const gpu_adaptor = has_cuarrays() ? CuArrays.cu : identity
2018-02-28 22:51:08 +00:00
# Map every leaf of `x` through `gpu_adaptor`.
# Calls `fmap` directly rather than the deprecated `mapleaves` alias.
gpu(x) = fmap(gpu_adaptor, x)
2019-01-25 10:06:37 +00:00
# Precision
# Adapting a real-valued array to a real number type `T` converts every
# element to `T`.
function adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real})
  return convert.(T, xs)
end
# Map every leaf of `m` through `adapt(T, x)` for the real number type `T`.
# Calls `fmap` directly rather than the deprecated `mapleaves` alias.
paramtype(T::Type{<:Real}, m) = fmap(x -> adapt(T, x), m)
# Shorthand for `paramtype(Float32, m)`.
function f32(m)
  return paramtype(Float32, m)
end

# Shorthand for `paramtype(Float64, m)`.
function f64(m)
  return paramtype(Float64, m)
end