rename and fix mapleaves
This commit is contained in:
parent
7aa0b43ceb
commit
c764b74eba
@ -19,16 +19,16 @@ loss(x, y) # ~ 3
|
|||||||
|
|
||||||
Note that we convert both the parameters (`W`, `b`) and the data set (`x`, `y`) to cuda arrays. Taking derivatives and training works exactly as before.
|
Note that we convert both the parameters (`W`, `b`) and the data set (`x`, `y`) to cuda arrays. Taking derivatives and training works exactly as before.
|
||||||
|
|
||||||
If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `fmap`, which allows you to alter all parameters of a model at once.
|
If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `mapleaves`, which allows you to alter all parameters of a model at once.
|
||||||
|
|
||||||
```julia
|
```julia
|
||||||
d = Dense(10, 5, σ)
|
d = Dense(10, 5, σ)
|
||||||
d = fmap(cu, d)
|
d = mapleaves(cu, d)
|
||||||
d.W # Tracked CuArray
|
d.W # Tracked CuArray
|
||||||
d(cu(rand(10))) # CuArray output
|
d(cu(rand(10))) # CuArray output
|
||||||
|
|
||||||
m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
|
m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
|
||||||
m = fmap(cu, m)
|
m = mapleaves(cu, m)
|
||||||
d(cu(rand(10)))
|
d(cu(rand(10)))
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@ using Juno, Requires
|
|||||||
using Lazy: @forward
|
using Lazy: @forward
|
||||||
|
|
||||||
export Chain, Dense, RNN, LSTM,
|
export Chain, Dense, RNN, LSTM,
|
||||||
SGD, param, params, fmap
|
SGD, param, params, mapleaves
|
||||||
|
|
||||||
using NNlib
|
using NNlib
|
||||||
export σ, relu, softmax
|
export σ, relu, softmax
|
||||||
|
23
src/tree.jl
23
src/tree.jl
@ -8,21 +8,24 @@ function treelike(T, fs = fieldnames(T))
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
# TODO: prewalk/postwalk with correct caching
|
|
||||||
# This is only correct in general for idempotent functions
|
|
||||||
|
|
||||||
isleaf(x) = isempty(children(x))
|
isleaf(x) = isempty(children(x))
|
||||||
|
|
||||||
fmap(f, x) = isleaf(x) ? f(x) : mapchildren(x -> fmap(f, x), x)
|
function mapleaves(f, x; cache = ObjectIdDict())
|
||||||
ffor(f, x) = isleaf(x) ? f(x) : foreach(x -> ffor(f, x), children(x))
|
haskey(cache, x) && return cache[x]
|
||||||
|
cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x)
|
||||||
|
end
|
||||||
|
|
||||||
using DataFlow: OSet
|
using DataFlow: OSet
|
||||||
|
|
||||||
|
function forleaves(f, x; seen = OSet())
|
||||||
|
x ∈ seen && return
|
||||||
|
push!(seen, x)
|
||||||
|
isleaf(x) ? f(x) : foreach(x -> forleaves(f, x, seen = seen), children(x))
|
||||||
|
return
|
||||||
|
end
|
||||||
|
|
||||||
function params(m)
|
function params(m)
|
||||||
ps, seen = [], OSet()
|
ps = []
|
||||||
ffor(m) do p
|
forleaves(p -> p isa TrackedArray && push!(ps, p), m)
|
||||||
p isa TrackedArray && p ∉ seen &&
|
|
||||||
(push!(ps, p); push!(seen, p))
|
|
||||||
end
|
|
||||||
return ps
|
return ps
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user