rename and fix mapleaves

This commit is contained in:
Mike J Innes 2017-10-17 01:08:15 +01:00
parent 7aa0b43ceb
commit c764b74eba
3 changed files with 17 additions and 14 deletions

View File

@ -19,16 +19,16 @@ loss(x, y) # ~ 3
Note that we convert both the parameters (`W`, `b`) and the data set (`x`, `y`) to cuda arrays. Taking derivatives and training works exactly as before.
If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `fmap`, which allows you to alter all parameters of a model at once.
If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `mapleaves`, which allows you to alter all parameters of a model at once.
```julia
d = Dense(10, 5, σ)
d = fmap(cu, d)
d = mapleaves(cu, d)
d.W # Tracked CuArray
d(cu(rand(10))) # CuArray output
m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
m = fmap(cu, m)
m = mapleaves(cu, m)
m(cu(rand(10)))
```

View File

@ -8,7 +8,7 @@ using Juno, Requires
using Lazy: @forward
export Chain, Dense, RNN, LSTM,
SGD, param, params, fmap
SGD, param, params, mapleaves
using NNlib
export σ, relu, softmax

View File

@ -8,21 +8,24 @@ function treelike(T, fs = fieldnames(T))
end
end
# TODO: prewalk/postwalk with correct caching
# This is only correct in general for idempotent functions
# A node is a leaf when `children` reports nothing underneath it.
function isleaf(x)
  kids = children(x)
  return isempty(kids)
end
# Apply `f` to each leaf of `x`, rebuilding non-leaf nodes via `mapchildren`.
# Nothing is cached here, so a node reachable along several paths is
# processed once per path.
function fmap(f, x)
  isleaf(x) && return f(x)
  return mapchildren(child -> fmap(f, child), x)
end
# Call `f` on each leaf of `x`, descending through `children` for non-leaf
# nodes. Yields `f(x)` when `x` itself is a leaf, otherwise whatever
# `foreach` returns.
function ffor(f, x)
  if isleaf(x)
    f(x)
  else
    foreach(child -> ffor(f, child), children(x))
  end
end
"""
    mapleaves(f, x; cache = ObjectIdDict())

Rebuild `x` with `f` applied to every leaf (as decided by `isleaf`); non-leaf
nodes are reconstructed through `mapchildren`.

Results are memoised in `cache`, keyed by the node itself, so a node that is
reached more than once is transformed only once and the stored result is
returned on every later visit. Pass your own `cache` to share it across calls.
"""
function mapleaves(f, x; cache = ObjectIdDict())
  if haskey(cache, x)
    return cache[x]
  end
  recurse = child -> mapleaves(f, child, cache = cache)
  mapped = isleaf(x) ? f(x) : mapchildren(recurse, x)
  cache[x] = mapped
  return mapped
end
using DataFlow: OSet
# Call `f` on every leaf reachable from `x`, visiting each node at most once.
#
# `seen` accumulates the nodes already visited; a node found in it is skipped,
# so shared sub-structures are not processed twice. Pass your own `seen` to
# carry the visited set across calls. Always returns `nothing`.
function forleaves(f, x; seen = OSet())
  # ASCII `in` is the same function as `∈`, and cannot be lost to encoding.
  x in seen && return
  push!(seen, x)
  if isleaf(x)
    f(x)
  else
    foreach(child -> forleaves(f, child, seen = seen), children(x))
  end
  return
end
"""
    params(m)

Collect the `TrackedArray` leaves of `m` into a vector and return it.

The traversal is delegated to `forleaves`, which is relied upon to skip nodes
it has already visited, so no separate `seen` set is kept here.
"""
function params(m)
  ps = []
  forleaves(p -> p isa TrackedArray && push!(ps, p), m)
  return ps
end