Compare commits

...

1 Commits

Author SHA1 Message Date
Elliot Saba 40bf65ac3f Add `leaves()` to make it simple to collect all parameters in a model
Useful for implementing things like L2 regularization, typically
requires a filtering afterwards such as collecting only
`AbstractArrays`, etc...
2019-04-15 16:30:03 -07:00
1 changed files with 17 additions and 1 deletions

View File

@ -9,10 +9,26 @@ children(x::NamedTuple) = x
mapchildren(f, x::Tuple) = map(f, x)
mapchildren(f, x::NamedTuple) = map(f, x)
function leaves(model, params = Any[])
# Get all children of the current `model` object
children = Flux.children(model)
# If there are no more children, we know we are at a leaf
if isempty(children)
push!(params, model)
else
# If there are more children, recurse
for c in children
leaves(c, params)
end
end
# Return the `params` collection in the end
return params
end
function treelike(m::Module, T, fs = fieldnames(T))
@eval m begin
Flux.children(x::$T) = ($([:(x.$f) for f in fs]...),)
Flux.mapchildren(f, x::$T) = $T(f.($children(x))...)
Flux.mapchildren(f, x::$T) = $T(map(f, children(x))...)
end
end