rename and fix mapleaves

This commit is contained in:
Mike J Innes 2017-10-17 01:08:15 +01:00
parent 7aa0b43ceb
commit c764b74eba
3 changed files with 17 additions and 14 deletions

View File

@ -19,16 +19,16 @@ loss(x, y) # ~ 3
Note that we convert both the parameters (`W`, `b`) and the data set (`x`, `y`) to cuda arrays. Taking derivatives and training works exactly as before. Note that we convert both the parameters (`W`, `b`) and the data set (`x`, `y`) to cuda arrays. Taking derivatives and training works exactly as before.
If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `fmap`, which allows you to alter all parameters of a model at once. If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `mapleaves`, which allows you to alter all parameters of a model at once.
```julia ```julia
d = Dense(10, 5, σ) d = Dense(10, 5, σ)
d = fmap(cu, d) d = mapleaves(cu, d)
d.W # Tracked CuArray d.W # Tracked CuArray
d(cu(rand(10))) # CuArray output d(cu(rand(10))) # CuArray output
m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax) m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
m = fmap(cu, m) m = mapleaves(cu, m)
d(cu(rand(10))) d(cu(rand(10)))
``` ```

View File

@ -8,7 +8,7 @@ using Juno, Requires
using Lazy: @forward using Lazy: @forward
export Chain, Dense, RNN, LSTM, export Chain, Dense, RNN, LSTM,
SGD, param, params, fmap SGD, param, params, mapleaves
using NNlib using NNlib
export σ, relu, softmax export σ, relu, softmax

View File

@ -8,21 +8,24 @@ function treelike(T, fs = fieldnames(T))
end end
end end
# TODO: prewalk/postwalk with correct caching
# This is only correct in general for idempotent functions
isleaf(x) = isempty(children(x)) isleaf(x) = isempty(children(x))
fmap(f, x) = isleaf(x) ? f(x) : mapchildren(x -> fmap(f, x), x) function mapleaves(f, x; cache = ObjectIdDict())
ffor(f, x) = isleaf(x) ? f(x) : foreach(x -> ffor(f, x), children(x)) haskey(cache, x) && return cache[x]
cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x)
end
using DataFlow: OSet using DataFlow: OSet
function forleaves(f, x; seen = OSet())
x seen && return
push!(seen, x)
isleaf(x) ? f(x) : foreach(x -> forleaves(f, x, seen = seen), children(x))
return
end
function params(m) function params(m)
ps, seen = [], OSet() ps = []
ffor(m) do p forleaves(p -> p isa TrackedArray && push!(ps, p), m)
p isa TrackedArray && p seen &&
(push!(ps, p); push!(seen, p))
end
return ps return ps
end end