Merge #1192
1192: Improve `restructure` performance r=dhairyagandhi96 a=MikeInnes A small change, but it significantly improves the performance on the following test case: ```julia julia> VERSION v"1.5.0-DEV.876" julia> using Flux, DiffEqFlux, BenchmarkTools julia> using Flux: mse julia> fastdense = FastDense(784, 32, tanh); julia> p = initial_params(fastdense); julia> dense = Dense(784, 32, tanh); julia> p,re = Flux.destructure(dense); julia> x = rand(Float32, 784, 10); julia> y = rand(Float32, 32, 10); julia> @btime gradient((x,p) -> mse(fastdense(x, p), y), x, p); 505.530 μs (87 allocations: 240.73 KiB) julia> @btime gradient((x,p) -> mse(re(p)(x), y), x, p); 107.796 μs (139 allocations: 340.94 KiB) ``` Co-authored-by: Mike J Innes <mike.j.innes@gmail.com>
This commit is contained in:
commit
22d5e318e5
|
@ -246,6 +246,10 @@ function _restructure(m, xs)
|
|||
end
|
||||
end
|
||||
|
||||
@adjoint function _restructure(m, xs)
|
||||
_restructure(m, xs), dm -> (nothing,destructure(dm)[1])
|
||||
end
|
||||
|
||||
"""
|
||||
destructure(m)
|
||||
|
||||
|
|
Loading…
Reference in New Issue