2019-01-25 10:06:37 +00:00
|
|
|
|
import Adapt: adapt, adapt_storage
|
2019-03-08 12:56:19 +00:00
|
|
|
|
import Zygote: IdSet
|
2018-01-16 17:58:14 +00:00
|
|
|
|
|
2017-09-27 20:11:21 +00:00
|
|
|
|
# Generic fallback: an arbitrary object exposes no children, so it is a leaf.
function children(x)
    return ()
end
|
|
|
|
|
# Generic fallback: an object without children is returned unchanged and `f`
# is never called on it.
function mapchildren(f, x)
    return x
end
|
|
|
|
|
|
2017-10-31 16:37:41 +00:00
|
|
|
|
# A tuple's children are its elements; the tuple itself already is that collection.
function children(x::Tuple)
    return x
end
|
2019-02-06 16:08:06 +00:00
|
|
|
|
# A named tuple's children are its fields; the named tuple is returned as-is.
function children(x::NamedTuple)
    return x
end
|
2017-10-31 16:37:41 +00:00
|
|
|
|
# Apply `f` to every element of the tuple, producing a new tuple.
function mapchildren(f, x::Tuple)
    return map(f, x)
end
|
2019-02-06 16:08:06 +00:00
|
|
|
|
# Apply `f` to every field value, producing a named tuple with the same keys.
function mapchildren(f, x::NamedTuple)
    return map(f, x)
end
|
2017-10-31 16:37:41 +00:00
|
|
|
|
|
2018-07-12 21:43:11 +00:00
|
|
|
|
# Register type `T` with the tree utilities by defining, inside module `m`:
#   * `Flux.children(x::T)`     -> the tuple `(x.f1, x.f2, ...)` over the names in `fs`
#   * `Flux.mapchildren(f, x::T)` -> `T(...)` called positionally with `f` broadcast
#                                    over `children(x)`
# `fs` defaults to every field of `T`. NOTE(review): the rebuild in `mapchildren`
# assumes `T` has a positional constructor accepting exactly these children in order.
function treelike(m::Module, T, fs = fieldnames(T))
    # Build the `x.<name>` access expressions up front; they are spliced into the tuple below.
    field_accesses = [:(x.$name) for name in fs]
    @eval m begin
        Flux.children(x::$T) = ($(field_accesses...),)
        Flux.mapchildren(f, x::$T) = $T(f.($children(x))...)
    end
end
|
|
|
|
|
|
2018-07-12 21:43:11 +00:00
|
|
|
|
# `@treelike T` or `@treelike T (a, b)`
#
# Expands to `treelike(<calling module>, T)` or, when a field tuple is given,
# `treelike(<calling module>, T, (:a, :b))`. Any other second argument raises
# an error showing the expected usage.
macro treelike(T, fs = nothing)
    # `=== nothing` is the idiomatic (and always-inferable) check for an omitted argument.
    fs === nothing || isexpr(fs, :tuple) || error("@treelike T (a, b)")
    # Turn the field names into a tuple of quoted Symbols; with no list, pass nothing
    # extra so `treelike` falls back to its default field set.
    fs = fs === nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
    :(treelike(@__MODULE__, $(esc(T)), $(fs...)))
end
|
|
|
|
|
|
2017-10-16 23:07:15 +00:00
|
|
|
|
# An object is a leaf exactly when `children` reports nothing beneath it.
function isleaf(x)
    kids = children(x)
    return isempty(kids)
end
|
2017-09-27 20:11:21 +00:00
|
|
|
|
|
2018-07-12 19:56:51 +00:00
|
|
|
|
# Rebuild the tree rooted at `x`. Every leaf (per `isleaf`) is replaced by `f(leaf)`,
# and interior nodes are rebuilt via `mapchildren`.
#
# `cache` is an `IdDict` keyed on object identity. A node reached more than once is
# transformed only the first time, and the stored result is reused afterwards.
# NOTE(review): the cache entry is written only after a node's children are mapped,
# so a cyclic structure would recurse without terminating.
function mapleaves(f, x; cache = IdDict())
    if haskey(cache, x)
        return cache[x]
    end
    recurse(child) = mapleaves(f, child; cache = cache)
    result = isleaf(x) ? f(x) : mapchildren(recurse, x)
    cache[x] = result
    return result
end
|
2017-09-27 20:58:34 +00:00
|
|
|
|
|
2018-07-09 15:57:44 +00:00
|
|
|
|
# Pre-order walk: call `f(x)`, then recurse into each of `children(x)` in order.
# Every visited object is recorded in `seen` before `f` runs, so an object reached
# again (shared or repeated children) is skipped. Always returns `nothing`.
function prefor(f, x; seen = IdSet())
    if x in seen
        return nothing
    end
    push!(seen, x)
    f(x)
    for child in children(x)
        prefor(f, child; seen = seen)
    end
    return nothing
end
|
|
|
|
|
|
2017-09-27 20:11:21 +00:00
|
|
|
|
# Collect every `AbstractArray{<:Real}` reachable from `m` into a fresh `Params`,
# in the order `prefor` visits them (pre-order over `children`).
function params(m)
    ps = Params()
    # `prefor` records each object in its `seen` set before invoking this callback and
    # returns early on any object already in that set, so the same array can never be
    # offered twice. The former linear `any(p′ -> p′ === p, ps)` scan therefore never
    # fired, and it only made collection quadratic in the number of parameters.
    prefor(m) do p
        p isa AbstractArray{<:Real} && push!(ps, p)
    end
    return ps
end
|
2017-11-07 19:34:35 +00:00
|
|
|
|
|
|
|
|
|
# Vararg convenience: `params(a, b, ...)` collects parameters from the tuple `(a, b, ...)`.
function params(ms...)
    return params(ms)
end
|
2018-02-26 23:10:59 +00:00
|
|
|
|
|
2018-03-06 02:45:31 +00:00
|
|
|
|
# Copy the arrays in `xs` into the parameters of `m`, pairing them in `params(m)` order.
# Each pair must have the same size, otherwise an error is raised.
# NOTE(review): pairing uses `zip`, so surplus entries on either side are silently ignored.
function loadparams!(m, xs)
    for (p, x) in zip(params(m), xs)
        if size(p) != size(x)
            error("Expected param size $(size(p)), got $(size(x))")
        end
        copyto!(p, x)
    end
end
|
|
|
|
|
|
2018-02-26 23:10:59 +00:00
|
|
|
|
# CPU/GPU movement conveniences
|
|
|
|
|
|
2018-02-28 22:51:08 +00:00
|
|
|
|
# Adapt every leaf of `m` to a plain `Array`, e.g. to bring GPU arrays back to the host.
function cpu(m)
    to_array(leaf) = adapt(Array, leaf)
    return mapleaves(to_array, m)
end
|
2018-02-26 23:10:59 +00:00
|
|
|
|
|
2019-08-27 07:33:15 +00:00
|
|
|
|
# Leaf transformation used by `gpu`: `CuArrays.cu` when `has_cuarrays()` is true,
# otherwise `identity`, which leaves leaves unchanged. `CuArrays.cu` is only
# referenced on the true branch.
const gpu_adaptor = has_cuarrays() ? CuArrays.cu : identity
|
|
|
|
|
|
2018-02-28 22:51:08 +00:00
|
|
|
|
# Apply `gpu_adaptor` to every leaf of `x`. This is a no-op transform when CuArrays is unavailable.
function gpu(x)
    return mapleaves(gpu_adaptor, x)
end
|
2019-01-25 10:06:37 +00:00
|
|
|
|
|
|
|
|
|
# Precision
|
|
|
|
|
|
|
|
|
|
# Adapting a real-valued array to a real element type `T` converts each element with
# `convert(T, ⋅)` (the explicit form of `convert.(T, xs)`).
adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = broadcast(convert, T, xs)
|
|
|
|
|
|
|
|
|
|
# Rebuild `m` with every leaf passed through `adapt(T, leaf)`. Real-valued arrays
# are converted to element type `T` via the `adapt_storage` method defined in this file.
function paramtype(T::Type{<:Real}, m)
    convert_leaf(leaf) = adapt(T, leaf)
    return mapleaves(convert_leaf, m)
end
|
|
|
|
|
|
|
|
|
|
# Shorthand for `paramtype(Float32, m)`.
function f32(m)
    return paramtype(Float32, m)
end
|
|
|
|
|
# Shorthand for `paramtype(Float64, m)`.
function f64(m)
    return paramtype(Float64, m)
end
|
2019-01-25 10:11:46 +00:00
|
|
|
|
|
|
|
|
|
# General parameter map
|
|
|
|
|
|
|
|
|
|
# Rebuild `m`, applying `f` to every leaf that is an `AbstractArray` or a `Number`.
# Any other leaf is kept unchanged.
function mapparams(f, m)
    function apply_if_param(leaf)
        if leaf isa AbstractArray || leaf isa Number
            return f(leaf)
        end
        return leaf
    end
    return mapleaves(apply_if_param, m)
end
|