Flux.jl/src/treelike.jl

72 lines
1.7 KiB
Julia
Raw Normal View History

2018-01-16 17:58:14 +00:00
import Adapt: adapt
2018-07-09 15:57:44 +00:00
import .Tracker: IdSet
2018-01-16 17:58:14 +00:00
2017-09-27 20:11:21 +00:00
"""
    children(x)

Return the sub-nodes of `x`. Arbitrary values have none (an empty tuple);
a tuple is its own list of children.
"""
children(::Any) = ()
children(t::Tuple) = t

"""
    mapchildren(f, x)

Rebuild `x` with `f` applied to each of its children. Values without children
are returned unchanged; for a tuple, `f` is mapped over its elements.
"""
mapchildren(f, leaf) = leaf
mapchildren(f, t::Tuple) = map(f, t)
2018-07-12 21:43:11 +00:00
"""
    treelike(m::Module, T, fs = fieldnames(T))

Evaluate, inside module `m`, definitions of `Flux.children` and
`Flux.mapchildren` for values of type `T`:

- `children(x::T)` returns the fields named in `fs`, as a tuple;
- `mapchildren(f, x::T)` calls the constructor `T` with `f` applied to each of
  those children.

NOTE(review): `mapchildren` passes only the `fs` fields to `T`, so this
presumably requires `fs` to match the constructor's arguments in order.
Confirm before passing a subset of fields.
"""
function treelike(m::Module, T, fs = fieldnames(T))
    # One `x.<field>` access expression per selected field.
    field_accesses = [:(x.$f) for f in fs]
    definitions = quote
        Flux.children(x::$T) = ($(field_accesses...),)
        Flux.mapchildren(f, x::$T) = $T(f.($children(x))...)
    end
    return Core.eval(m, definitions)
end
2018-07-12 21:43:11 +00:00
"""
    treelike(T, fs = fieldnames(T))

Deprecated form of `@treelike T`. Emits a deprecation warning, then forwards to
`treelike(module, T, fs)` using `Base._current_module()` as the target module.
"""
function treelike(T, fs = fieldnames(T))
    Base.depwarn("`treelike(T)` is deprecated, use `@treelike T`", :treelike)
    # NOTE(review): `Base._current_module` is believed to be gone in Julia 1.0+,
    # so this shim probably only works on 0.7. Confirm before relying on it.
    target = Base._current_module()
    return treelike(target, T, fs)
end
"""
    @treelike T
    @treelike T (a, b)

Make type `T` traversable through `children`/`mapchildren` (and therefore by
`mapleaves`, `prefor` and `params`). Expands to
`treelike(@__MODULE__, T)` or, when a tuple of field names is given,
`treelike(@__MODULE__, T, (:a, :b))`. Without a field list, `treelike` falls
back to `fieldnames(T)`.

Throws an error at expansion time if the second argument is not a tuple.
"""
macro treelike(T, fs = nothing)
    fs === nothing || isexpr(fs, :tuple) || error("@treelike T (a, b)")
    # Quote each bare field name so the call receives a tuple of Symbols
    # rather than trying to evaluate variables named `a`, `b`, ...
    fields = fs === nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
    return :(treelike(@__MODULE__, $(esc(T)), $(fields...)))
end
2017-10-16 23:07:15 +00:00
"""
    isleaf(x)

Return `true` when `x` has no sub-nodes, i.e. `children(x)` is empty.
"""
function isleaf(x)
    kids = children(x)
    return isempty(kids)
end
2017-09-27 20:11:21 +00:00
2018-07-12 19:56:51 +00:00
"""
    mapleaves(f, x; cache = IdDict())

Rebuild `x` with `f` applied to every leaf (a node for which `isleaf` is true).
Interior nodes are reconstructed through `mapchildren`.

Results are memoised in `cache`, keyed by node. A node reached more than once
is mapped a single time, and the same result is reused.
"""
function mapleaves(f, x; cache = IdDict())
    if haskey(cache, x)
        return cache[x]
    end
    result = isleaf(x) ? f(x) :
        mapchildren(child -> mapleaves(f, child; cache = cache), x)
    cache[x] = result
    return result
end
2017-09-27 20:58:34 +00:00
2018-07-09 15:57:44 +00:00
"""
    prefor(f, x; seen = IdSet())

Call `f` on `x`, then recursively on each of `children(x)`, so parents are
visited before their children (pre-order).

Every visited node is added to `seen`. A node reached again (a shared
sub-tree, or a cycle) is skipped, so `f` runs on it only once and traversal
terminates. Returns `nothing`.
"""
function prefor(f, x; seen = IdSet())
    x in seen && return
    # Record the node *before* recursing; without this, `seen` stayed empty,
    # shared nodes were revisited and cyclic structures recursed forever.
    push!(seen, x)
    f(x)
    foreach(child -> prefor(f, child; seen = seen), children(x))
    return
end
2017-09-27 20:11:21 +00:00
"""
    params(m)

Collect the parameters reachable from `m`: every node visited by `prefor`
that satisfies both `Tracker.istracked` and `Tracker.isleaf`.

Each parameter appears once, in first-visit order. Duplicates are detected by
identity (`===`). Returns a `Vector{Any}`.
"""
function params(m)
    ps = []
    prefor(m) do p
        # The comparison variable must not be called `p`: the old
        # `p -> p === p` shadowed the candidate, was always true, and so
        # blocked every parameter after the first.
        if Tracker.istracked(p) && Tracker.isleaf(p) && !any(q -> q === p, ps)
            push!(ps, p)
        end
    end
    return ps
end
2017-11-07 19:34:35 +00:00
# Gather parameters from several models at once. The arguments are bundled
# into a tuple, which `children` treats as a node whose children are the
# individual models.
params(models...) = params(models)
2018-02-26 23:10:59 +00:00
2018-03-06 02:45:31 +00:00
"""
    loadparams!(m, xs)

Overwrite the parameters of `m` in place with the values in `xs`. Entries are
paired in order with `params(m)`.

Each pair must have the same `size`; otherwise an error is thrown. If the two
sequences differ in length, the surplus entries are ignored, because `zip`
stops at the shorter one. Returns `nothing`.
"""
function loadparams!(m, xs)
    for (p, x) in zip(params(m), xs)
        size(p) == size(x) ||
            error("Expected param size $(size(p)), got $(size(x))")
        # `copyto!` is the in-place array copy from Julia 0.7 onward.
        # The array method of `copy!` was deprecated in 0.7 and only returned
        # to Base in 1.1. The size check above makes the two equivalent here.
        copyto!(data(p), data(x))
    end
end
2018-02-26 23:10:59 +00:00
# CPU/GPU movement conveniences
2018-02-28 22:51:08 +00:00
"""
    cpu(m)

Return `m` with every leaf converted through `adapt(Array, leaf)`. This is
intended to bring data back to ordinary CPU `Array`s.
"""
function cpu(m)
    to_array = leaf -> adapt(Array, leaf)
    return mapleaves(to_array, m)
end
2018-02-26 23:10:59 +00:00
2018-02-28 22:51:08 +00:00
# Function applied to every leaf by `gpu`. It starts as `identity`, so `gpu`
# leaves data untouched until a GPU backend replaces it below.
gpu_adaptor = identity

# Once the CuArrays package (identified by the UUID below) is available,
# rebind the adaptor to `CuArrays.cu`. `global` is needed because the
# assignment runs inside the `@require` body.
# NOTE(review): this relies on `@init`/`@require` (defined outside this file)
# running the body when CuArrays loads. Confirm the exact trigger.
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
    global gpu_adaptor = CuArrays.cu
end

"""
    gpu(x)

Return `x` with every leaf passed through the current `gpu_adaptor`, via
`mapleaves`. This is a no-op while the adaptor is still `identity`.
"""
function gpu(x)
    return mapleaves(gpu_adaptor, x)
end