more general fmap
This commit is contained in:
parent
32c8698869
commit
e02e320008
@ -19,16 +19,16 @@ loss(x, y) # ~ 3
|
|||||||
|
|
||||||
Note that we convert both the parameters (`W`, `b`) and the data set (`x`, `y`) to cuda arrays. Taking derivatives and training works exactly as before.
|
Note that we convert both the parameters (`W`, `b`) and the data set (`x`, `y`) to cuda arrays. Taking derivatives and training works exactly as before.
|
||||||
|
|
||||||
If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `mapparams`, which allows you to alter all parameters of a model at once.
|
If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `fmap`, which allows you to alter all parameters of a model at once.
|
||||||
|
|
||||||
```julia
|
```julia
|
||||||
d = Dense(10, 5, σ)
|
d = Dense(10, 5, σ)
|
||||||
d = mapparams(cu, d)
|
d = fmap(cu, d)
|
||||||
d.W # Tracked CuArray
|
d.W # Tracked CuArray
|
||||||
d(cu(rand(10))) # CuArray output
|
d(cu(rand(10))) # CuArray output
|
||||||
|
|
||||||
m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
|
m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
|
||||||
m = mapparams(cu, m)
|
m = fmap(cu, m)
|
||||||
d(cu(rand(10)))
|
d(cu(rand(10)))
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@ using Juno, Requires
|
|||||||
using Lazy: @forward
|
using Lazy: @forward
|
||||||
|
|
||||||
export Chain, Dense, RNN, LSTM,
|
export Chain, Dense, RNN, LSTM,
|
||||||
SGD, param, params, mapparams
|
SGD, param, params, fmap
|
||||||
|
|
||||||
using NNlib
|
using NNlib
|
||||||
export σ, relu, softmax
|
export σ, relu, softmax
|
||||||
|
11
src/tree.jl
11
src/tree.jl
@ -11,15 +11,18 @@ end
|
|||||||
# TODO: prewalk/postwalk with correct caching
|
# TODO: prewalk/postwalk with correct caching
|
||||||
# This is only correct in general for idempotent functions
|
# This is only correct in general for idempotent functions
|
||||||
|
|
||||||
mapparams(f, x::AbstractArray) = f(x)
|
isleaf(x) = isempty(children(x))
|
||||||
mapparams(f, x) = mapchildren(x -> mapparams(f, x), x)
|
|
||||||
|
|
||||||
forparams(f, x) = (mapparams(x -> (f(x); x), x); return)
|
fmap(f, x) = isleaf(x) ? f(x) : mapchildren(x -> fmap(f, x), x)
|
||||||
|
ffor(f, x) = isleaf(x) ? f(x) : foreach(x -> ffor(f, x), children(x))
|
||||||
|
|
||||||
using DataFlow: OSet
|
using DataFlow: OSet
|
||||||
|
|
||||||
function params(m)
|
function params(m)
|
||||||
ps, seen = [], OSet()
|
ps, seen = [], OSet()
|
||||||
forparams(p -> p ∉ seen && (push!(ps, p); push!(seen, p)), m)
|
ffor(m) do p
|
||||||
|
p isa TrackedArray && p ∉ seen &&
|
||||||
|
(push!(ps, p); push!(seen, p))
|
||||||
|
end
|
||||||
return ps
|
return ps
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user