Add `leaves()` to make it simple to collect all parameters in a model
Useful for implementing things like L2 regularization, typically requires a filtering afterwards such as collecting only `AbstractArrays`, etc...
This commit is contained in:
parent
05b1844419
commit
40bf65ac3f
|
@ -9,10 +9,26 @@ children(x::NamedTuple) = x
|
|||
mapchildren(f, x::Tuple) = map(f, x)
|
||||
mapchildren(f, x::NamedTuple) = map(f, x)
|
||||
|
||||
function leaves(model, params = Any[])
|
||||
# Get all children of the current `model` object
|
||||
children = Flux.children(model)
|
||||
# If there are no more children, we know we are at a leaf
|
||||
if isempty(children)
|
||||
push!(params, model)
|
||||
else
|
||||
# If there are more children, recurse
|
||||
for c in children
|
||||
leaves(c, params)
|
||||
end
|
||||
end
|
||||
# Return the `params` collection in the end
|
||||
return params
|
||||
end
|
||||
|
||||
function treelike(m::Module, T, fs = fieldnames(T))
|
||||
@eval m begin
|
||||
Flux.children(x::$T) = ($([:(x.$f) for f in fs]...),)
|
||||
Flux.mapchildren(f, x::$T) = $T(f.($children(x))...)
|
||||
Flux.mapchildren(f, x::$T) = $T(map(f, children(x))...)
|
||||
end
|
||||
end
|
||||
|
||||
|
|
Loading…
Reference in New Issue