Merge #1192
1192: Improve `restructure` performance r=dhairyagandhi96 a=MikeInnes A small change, but it significantly improves the performance on the following test case: ```julia julia> VERSION v"1.5.0-DEV.876" julia> using Flux, DiffEqFlux, BenchmarkTools julia> using Flux: mse julia> fastdense = FastDense(784, 32, tanh); julia> p = initial_params(fastdense); julia> dense = Dense(784, 32, tanh); julia> p,re = Flux.destructure(dense); julia> x = rand(Float32, 784, 10); julia> y = rand(Float32, 32, 10); julia> @btime gradient((x,p) -> mse(fastdense(x, p), y), x, p); 505.530 μs (87 allocations: 240.73 KiB) julia> @btime gradient((x,p) -> mse(re(p)(x), y), x, p); 107.796 μs (139 allocations: 340.94 KiB) ``` Co-authored-by: Mike J Innes <mike.j.innes@gmail.com>
This commit is contained in:
commit
22d5e318e5
|
@ -246,6 +246,10 @@ function _restructure(m, xs)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@adjoint function _restructure(m, xs)
|
||||||
|
_restructure(m, xs), dm -> (nothing,destructure(dm)[1])
|
||||||
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
destructure(m)
|
destructure(m)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue