![bors[bot]](/assets/img/avatar_default.png)
865: Functor r=MikeInnes a=MikeInnes This refactors our current `@treelike` infrastructure. It somewhat formalises what we're doing around the idea of a Flux model as a functor, i.e. something that can be mapped over. This is much more flexible than what we had before, and avoids some issues. It allows layers to have state that isn't mappable; it allows for dispatch when walking the tree, which means layers like `BatchNorm` can have non-trainable parameters; and it also allows for zipped mapping like `fmap(+, xs, ys)`, which isn't implemented yet but will be useful for the new optimisers work. The main downside is that the term `functor` has been previously used in the Julia community as a malapropism for "thing that behaves like a function"; but hopefully this can start to reduce that usage. Co-authored-by: Mike Innes <mike.j.innes@gmail.com>
64 lines · 1.5 KiB · Julia
using Flux, CuArrays, Test
|
|
using Flux: pullback
|
|
|
|
# For each recurrent layer type, check that the gradient taken with the
# model as an explicit argument agrees with the gradient taken through
# implicit `params`, when the layer lives on the GPU.
@testset for R in [RNN, GRU, LSTM]
  # GPU copy of the layer and a random length-10 input vector.
  layer = gpu(R(10, 5))
  input = rand(10) |> gpu

  # Explicit style: differentiate w.r.t. the model itself.
  grads = gradient(m -> sum(m(input)), layer)
  explicit = grads[1]

  # Reset hidden state so both gradient computations start identically.
  Flux.reset!(layer)

  # Implicit style: differentiate w.r.t. the tracked parameter set.
  implicit = gradient(() -> sum(layer(input)), params(layer))

  # Both approaches must produce the same input-weight gradient.
  @test collect(explicit[].cell[].Wi) == collect(implicit[layer.cell.Wi])
end
|
|
|
|
# Compare CPU and GPU (CuArrays) execution of the recurrent layers:
# forward values, pullback gradients (weights, input, hidden state),
# and forward evaluation on one-hot inputs, for single samples and batches.
@testset "RNN" begin
  @testset for R in [RNN, GRU, LSTM], batch_size in (1, 5)
    rnn = R(10, 5)
    # Move a structural copy of the layer to the GPU via fmap.
    curnn = fmap(gpu, rnn)

    # Start both copies from a freshly reset hidden state.
    Flux.reset!(rnn)
    Flux.reset!(curnn)
    x = batch_size == 1 ?
      rand(10) :
      rand(10, batch_size)
    cux = gpu(x)

    # Forward pass with pullback so we can push a cotangent through both.
    y, back = pullback((r, x) -> (r(x)), rnn, x)
    cuy, cuback = pullback((r, x) -> (r(x)), curnn, cux)

    @test y ≈ collect(cuy)
    # The CUDNN descriptor cache should now hold an entry for this cell.
    @test haskey(Flux.CUDA.descs, curnn.cell)

    # Backward pass: same random cotangent on CPU and GPU.
    ȳ = randn(size(y))
    m̄, x̄ = back(ȳ)
    cum̄, cux̄ = cuback(gpu(ȳ))

    # NOTE(review): three bare expressions (m̄[].cell[].Wi, m̄[].state,
    # cum̄[].state) were removed here — they were no-op leftovers from
    # debugging and had no effect on the tests below.

    # Input and weight gradients must match between CPU and GPU.
    @test x̄ ≈ collect(cux̄)
    @test m̄[].cell[].Wi ≈ collect(cum̄[].cell[].Wi)
    @test m̄[].cell[].Wh ≈ collect(cum̄[].cell[].Wh)
    @test m̄[].cell[].b ≈ collect(cum̄[].cell[].b)
    # LSTM carries a (h, c) tuple state; RNN/GRU carry a single array.
    if m̄[].state isa Tuple
      for (x, cx) in zip(m̄[].state, cum̄[].state)
        @test x ≈ collect(cx)
      end
    else
      @test m̄[].state ≈ collect(cum̄[].state)
    end

    # One-hot inputs: run each layer twice (stateful) and compare outputs.
    Flux.reset!(rnn)
    Flux.reset!(curnn)
    ohx = batch_size == 1 ?
      Flux.onehot(rand(1:10), 1:10) :
      Flux.onehotbatch(rand(1:10, batch_size), 1:10)
    cuohx = gpu(ohx)
    y = (rnn(ohx); rnn(ohx))
    cuy = (curnn(cuohx); curnn(cuohx))

    @test y ≈ collect(cuy)
  end
end
|