more general fmap

Mike J Innes 2017-10-17 00:07:15 +01:00
parent 32c8698869
commit e02e320008
3 changed files with 11 additions and 8 deletions

@@ -19,16 +19,16 @@ loss(x, y) # ~ 3
Note that we convert both the parameters (`W`, `b`) and the data set (`x`, `y`) to cuda arrays. Taking derivatives and training works exactly as before.
-If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `mapparams`, which allows you to alter all parameters of a model at once.
+If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `fmap`, which allows you to alter all parameters of a model at once.
```julia
d = Dense(10, 5, σ)
-d = mapparams(cu, d)
+d = fmap(cu, d)
d.W # Tracked CuArray
d(cu(rand(10))) # CuArray output
m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
-m = mapparams(cu, m)
+m = fmap(cu, m)
d(cu(rand(10)))
```
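
For orientation, here is a minimal sketch of the intended usage after this rename. It is an assumption-laden example, not part of the page itself: it presumes the CuArrays-era setup described above (a CUDA device and the `cu` converter) and simply calls the converted `Chain` directly.

```julia
using Flux, CuArrays   # assumes the CuArrays-era GPU setup this page describes

m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
m = fmap(cu, m)        # recursively converts every internal parameter array

x = cu(rand(10))       # input data has to live on the GPU as well
m(x)                   # forward pass now runs on CuArrays end to end
```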

@@ -8,7 +8,7 @@ using Juno, Requires
using Lazy: @forward
export Chain, Dense, RNN, LSTM,
-SGD, param, params, mapparams
+SGD, param, params, fmap
using NNlib
export σ, relu, softmax

@@ -11,15 +11,18 @@ end
# TODO: prewalk/postwalk with correct caching
# This is only correct in general for idempotent functions
-mapparams(f, x::AbstractArray) = f(x)
-mapparams(f, x) = mapchildren(x -> mapparams(f, x), x)
+isleaf(x) = isempty(children(x))
-forparams(f, x) = (mapparams(x -> (f(x); x), x); return)
+fmap(f, x) = isleaf(x) ? f(x) : mapchildren(x -> fmap(f, x), x)
+ffor(f, x) = isleaf(x) ? f(x) : foreach(x -> ffor(f, x), children(x))
using DataFlow: OSet
function params(m)
  ps, seen = [], OSet()
-  forparams(p -> p ∉ seen && (push!(ps, p); push!(seen, p)), m)
+  ffor(m) do p
+    p isa TrackedArray && p ∉ seen &&
+      (push!(ps, p); push!(seen, p))
+  end
  return ps
end
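
For reference, a standalone sketch of the traversal semantics this hunk introduces. The `Affine` type and its `children`/`mapchildren` methods below are hypothetical stand-ins for Flux's treelike interface, so the snippet runs without Flux; the `isleaf`, `fmap`, and `ffor` definitions are copied from the diff above.

```julia
# Hypothetical stand-ins for Flux's treelike interface.
struct Affine
  W
  b
end

children(x) = ()                     # plain values are leaves by default
children(a::Affine) = (a.W, a.b)
mapchildren(f, a::Affine) = Affine(f(a.W), f(a.b))

# Definitions from the diff above.
isleaf(x) = isempty(children(x))
fmap(f, x) = isleaf(x) ? f(x) : mapchildren(x -> fmap(f, x), x)
ffor(f, x) = isleaf(x) ? f(x) : foreach(x -> ffor(f, x), children(x))

a = Affine(rand(5, 10), rand(5))
a2 = fmap(x -> 2 .* x, a)            # every leaf transformed, structure preserved
ffor(x -> println(size(x)), a)       # visits each leaf: prints (5, 10) then (5,)
```

The new `params` walks a model with the same `ffor`, pushing every `TrackedArray` leaf into an ordered set, so a parameter shared between layers is collected only once.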