Flux.jl/src/tree.jl

32 lines
714 B
Julia
Raw Normal View History

2017-09-27 20:11:21 +00:00
"""
    children(x)

Generic fallback: return the direct sub-nodes of `x`. A value whose type has not
been registered with `treelike` has no children, so this returns the empty tuple.
"""
children(::Any) = ()
"""
    mapchildren(f, x)

Generic fallback: rebuild `x` with `f` applied to each of its children. A value
with no registered children is returned as-is (the same object) and `f` is never
called.
"""
mapchildren(::Any, x) = x
"""
    treelike(T, fs = fieldnames(T))

Register type `T` as an internal tree node whose children are the fields named in
`fs` (all fields by default). Two methods are generated with `@eval`:

- `children(x::T)` returns a tuple of the selected field values, in `fs` order.
- `mapchildren(f, x::T)` applies `f` to every child and passes the results
  positionally to the constructor `T(...)`.

NOTE(review): `mapchildren` assumes `T` has a constructor that takes exactly the
`fs` field values in that order; with a strict subset of fields this only works if
such a constructor exists. Confirm this per registered type.
"""
function treelike(T, fs = fieldnames(T))
    # Build one `x.<field>` accessor expression per selected field.
    getters = map(name -> :(x.$name), fs)
    @eval begin
        children(x::$T) = ($(getters...),)
        mapchildren(f, x::$T) = $T(f.(children(x))...)
    end
end
2017-10-16 23:07:15 +00:00
"""
    isleaf(x)

Return `true` when `x` has no children, i.e. when `children(x)` is empty.
"""
function isleaf(node)
    kids = children(node)
    return isempty(kids)
end
2017-09-27 20:11:21 +00:00
2017-10-17 00:08:15 +00:00
"""
    mapleaves(f, x; cache = ObjectIdDict())

Return a copy of the tree `x` with `f` applied to every leaf (a node for which
`isleaf` is true). Internal nodes are rebuilt with `mapchildren`.

Each result is memoised in `cache`, keyed by the node object (`ObjectIdDict` uses
object identity). A node reached more than once is therefore transformed only once,
and the same result is reused.

NOTE(review): a node is stored in `cache` only after its children have been
processed, so a cyclic structure would recurse without terminating.
"""
function mapleaves(f, x; cache = ObjectIdDict())
    haskey(cache, x) && return cache[x]
    recurse(child) = mapleaves(f, child; cache = cache)
    result = isleaf(x) ? f(x) : mapchildren(recurse, x)
    cache[x] = result
    return result
end
2017-09-27 20:58:34 +00:00
using DataFlow: OSet
2017-09-27 20:11:21 +00:00
2017-10-17 00:08:15 +00:00
"""
    forleaves(f, x; seen = OSet())

Walk the tree rooted at `x` depth-first and call `f` on every leaf (a node for
which `isleaf` is true). The return value of `f` is discarded, and the function
itself always returns `nothing`.

A node that is already in `seen` is skipped, and each visited node is added to
`seen` before its children are walked. The walk therefore terminates on cyclic
structures, and repeated nodes are processed once. What counts as "the same node"
depends on DataFlow's `OSet` (presumably object identity; confirm).

# Keywords
- `seen`: set of nodes already visited; it is mutated in place.
"""
function forleaves(f, x; seen = OSet())
    # Membership test; `in` is the ASCII spelling of `∈`.
    x in seen && return
    push!(seen, x)
    if isleaf(x)
        f(x)
    else
        foreach(child -> forleaves(f, child; seen = seen), children(x))
    end
    return
end
2017-09-27 20:11:21 +00:00
"""
    params(m)

Collect every `TrackedArray` leaf reachable from `m` into a `Vector{Any}`, in the
order in which `forleaves` visits them. Because `forleaves` skips nodes it has
already seen, each such leaf is collected at most once.
"""
function params(m)
    ps = Any[]
    forleaves(m) do p
        p isa TrackedArray && push!(ps, p)
    end
    return ps
end