rename and fix mapleaves
This commit is contained in:
parent
7aa0b43ceb
commit
c764b74eba
|
@ -19,16 +19,16 @@ loss(x, y) # ~ 3
|
|||
|
||||
Note that we convert both the parameters (`W`, `b`) and the data set (`x`, `y`) to cuda arrays. Taking derivatives and training works exactly as before.
|
||||
|
||||
If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `fmap`, which allows you to alter all parameters of a model at once.
|
||||
If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `mapleaves`, which allows you to alter all parameters of a model at once.
|
||||
|
||||
```julia
|
||||
d = Dense(10, 5, σ)
|
||||
d = fmap(cu, d)
|
||||
d = mapleaves(cu, d)
|
||||
d.W # Tracked CuArray
|
||||
d(cu(rand(10))) # CuArray output
|
||||
|
||||
m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
|
||||
m = fmap(cu, m)
|
||||
m = mapleaves(cu, m)
|
||||
d(cu(rand(10)))
|
||||
```
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ using Juno, Requires
|
|||
using Lazy: @forward
|
||||
|
||||
export Chain, Dense, RNN, LSTM,
|
||||
SGD, param, params, fmap
|
||||
SGD, param, params, mapleaves
|
||||
|
||||
using NNlib
|
||||
export σ, relu, softmax
|
||||
|
|
23
src/tree.jl
23
src/tree.jl
|
@ -8,21 +8,24 @@ function treelike(T, fs = fieldnames(T))
|
|||
end
|
||||
end
|
||||
|
||||
# TODO: prewalk/postwalk with correct caching
|
||||
# This is only correct in general for idempotent functions
|
||||
|
||||
isleaf(x) = isempty(children(x))
|
||||
|
||||
fmap(f, x) = isleaf(x) ? f(x) : mapchildren(x -> fmap(f, x), x)
|
||||
ffor(f, x) = isleaf(x) ? f(x) : foreach(x -> ffor(f, x), children(x))
|
||||
function mapleaves(f, x; cache = ObjectIdDict())
|
||||
haskey(cache, x) && return cache[x]
|
||||
cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x)
|
||||
end
|
||||
|
||||
using DataFlow: OSet
|
||||
|
||||
function forleaves(f, x; seen = OSet())
|
||||
x ∈ seen && return
|
||||
push!(seen, x)
|
||||
isleaf(x) ? f(x) : foreach(x -> forleaves(f, x, seen = seen), children(x))
|
||||
return
|
||||
end
|
||||
|
||||
function params(m)
|
||||
ps, seen = [], OSet()
|
||||
ffor(m) do p
|
||||
p isa TrackedArray && p ∉ seen &&
|
||||
(push!(ps, p); push!(seen, p))
|
||||
end
|
||||
ps = []
|
||||
forleaves(p -> p isa TrackedArray && push!(ps, p), m)
|
||||
return ps
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue