more general fmap
This commit is contained in:
parent
32c8698869
commit
e02e320008
|
@ -19,16 +19,16 @@ loss(x, y) # ~ 3
|
|||
|
||||
Note that we convert both the parameters (`W`, `b`) and the data set (`x`, `y`) to cuda arrays. Taking derivatives and training works exactly as before.
|
||||
|
||||
If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `mapparams`, which allows you to alter all parameters of a model at once.
|
||||
If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `fmap`, which allows you to alter all parameters of a model at once.
|
||||
|
||||
```julia
|
||||
d = Dense(10, 5, σ)
|
||||
d = mapparams(cu, d)
|
||||
d = fmap(cu, d)
|
||||
d.W # Tracked CuArray
|
||||
d(cu(rand(10))) # CuArray output
|
||||
|
||||
m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
|
||||
m = mapparams(cu, m)
|
||||
m = fmap(cu, m)
|
||||
d(cu(rand(10)))
|
||||
```
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ using Juno, Requires
|
|||
using Lazy: @forward
|
||||
|
||||
export Chain, Dense, RNN, LSTM,
|
||||
SGD, param, params, mapparams
|
||||
SGD, param, params, fmap
|
||||
|
||||
using NNlib
|
||||
export σ, relu, softmax
|
||||
|
|
11
src/tree.jl
11
src/tree.jl
|
@ -11,15 +11,18 @@ end
|
|||
# TODO: prewalk/postwalk with correct caching
|
||||
# This is only correct in general for idempotent functions
|
||||
|
||||
mapparams(f, x::AbstractArray) = f(x)
|
||||
mapparams(f, x) = mapchildren(x -> mapparams(f, x), x)
|
||||
isleaf(x) = isempty(children(x))
|
||||
|
||||
forparams(f, x) = (mapparams(x -> (f(x); x), x); return)
|
||||
fmap(f, x) = isleaf(x) ? f(x) : mapchildren(x -> fmap(f, x), x)
|
||||
ffor(f, x) = isleaf(x) ? f(x) : foreach(x -> ffor(f, x), children(x))
|
||||
|
||||
using DataFlow: OSet
|
||||
|
||||
function params(m)
|
||||
ps, seen = [], OSet()
|
||||
forparams(p -> p ∉ seen && (push!(ps, p); push!(seen, p)), m)
|
||||
ffor(m) do p
|
||||
p isa TrackedArray && p ∉ seen &&
|
||||
(push!(ps, p); push!(seen, p))
|
||||
end
|
||||
return ps
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue