Add leaves()
to make it simple to collect all parameters in a model
Useful for implementing things like L2 regularization, typically requires a filtering afterwards such as collecting only `AbstractArrays`, etc...
This commit is contained in:
parent
05b1844419
commit
40bf65ac3f
@ -9,10 +9,26 @@ children(x::NamedTuple) = x
|
|||||||
mapchildren(f, x::Tuple) = map(f, x)
|
mapchildren(f, x::Tuple) = map(f, x)
|
||||||
mapchildren(f, x::NamedTuple) = map(f, x)
|
mapchildren(f, x::NamedTuple) = map(f, x)
|
||||||
|
|
||||||
|
function leaves(model, params = Any[])
|
||||||
|
# Get all children of the current `model` object
|
||||||
|
children = Flux.children(model)
|
||||||
|
# If there are no more children, we know we are at a leaf
|
||||||
|
if isempty(children)
|
||||||
|
push!(params, model)
|
||||||
|
else
|
||||||
|
# If there are more children, recurse
|
||||||
|
for c in children
|
||||||
|
leaves(c, params)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
# Return the `params` collection in the end
|
||||||
|
return params
|
||||||
|
end
|
||||||
|
|
||||||
function treelike(m::Module, T, fs = fieldnames(T))
|
function treelike(m::Module, T, fs = fieldnames(T))
|
||||||
@eval m begin
|
@eval m begin
|
||||||
Flux.children(x::$T) = ($([:(x.$f) for f in fs]...),)
|
Flux.children(x::$T) = ($([:(x.$f) for f in fs]...),)
|
||||||
Flux.mapchildren(f, x::$T) = $T(f.($children(x))...)
|
Flux.mapchildren(f, x::$T) = $T(map(f, children(x))...)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user