more general fmap

This commit is contained in:
Mike J Innes 2017-10-17 00:07:15 +01:00
parent 32c8698869
commit e02e320008
3 changed files with 11 additions and 8 deletions

View File

@ -19,16 +19,16 @@ loss(x, y) # ~ 3
Note that we convert both the parameters (`W`, `b`) and the data set (`x`, `y`) to cuda arrays. Taking derivatives and training works exactly as before. Note that we convert both the parameters (`W`, `b`) and the data set (`x`, `y`) to cuda arrays. Taking derivatives and training works exactly as before.
If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `mapparams`, which allows you to alter all parameters of a model at once. If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `fmap`, which allows you to alter all parameters of a model at once.
```julia ```julia
d = Dense(10, 5, σ) d = Dense(10, 5, σ)
d = mapparams(cu, d) d = fmap(cu, d)
d.W # Tracked CuArray d.W # Tracked CuArray
d(cu(rand(10))) # CuArray output d(cu(rand(10))) # CuArray output
m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax) m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
m = mapparams(cu, m) m = fmap(cu, m)
d(cu(rand(10))) d(cu(rand(10)))
``` ```

View File

@ -8,7 +8,7 @@ using Juno, Requires
using Lazy: @forward using Lazy: @forward
export Chain, Dense, RNN, LSTM, export Chain, Dense, RNN, LSTM,
SGD, param, params, mapparams SGD, param, params, fmap
using NNlib using NNlib
export σ, relu, softmax export σ, relu, softmax

View File

@ -11,15 +11,18 @@ end
# TODO: prewalk/postwalk with correct caching # TODO: prewalk/postwalk with correct caching
# This is only correct in general for idempotent functions # This is only correct in general for idempotent functions
mapparams(f, x::AbstractArray) = f(x) isleaf(x) = isempty(children(x))
mapparams(f, x) = mapchildren(x -> mapparams(f, x), x)
forparams(f, x) = (mapparams(x -> (f(x); x), x); return) fmap(f, x) = isleaf(x) ? f(x) : mapchildren(x -> fmap(f, x), x)
ffor(f, x) = isleaf(x) ? f(x) : foreach(x -> ffor(f, x), children(x))
using DataFlow: OSet using DataFlow: OSet
function params(m) function params(m)
ps, seen = [], OSet() ps, seen = [], OSet()
forparams(p -> p seen && (push!(ps, p); push!(seen, p)), m) ffor(m) do p
p isa TrackedArray && p seen &&
(push!(ps, p); push!(seen, p))
end
return ps return ps
end end