
```julia
julia> p
Tracked 2×3 Array{Float64,2}:
 1.0  3.0  5.0
 2.0  4.0  6.0
```

Before (row-wise mean written by hand as `sum(p, 2) ./ size(p, 2)`):

```julia
julia> @benchmark Flux.Tracker.back!(sum($p, 2) ./ size($p, 2), ones(2, 1))
BenchmarkTools.Trial:
  memory estimate:  3.44 KiB
  allocs estimate:  75
  --------------
  minimum time:     20.438 μs (0.00% GC)
  median time:      21.239 μs (0.00% GC)
  mean time:        22.354 μs (1.68% GC)
  maximum time:     3.811 ms (98.51% GC)
  --------------
  samples:          10000
  evals/sample:     1
```

After (calling `mean(p, 2)` directly):

```julia
julia> @benchmark Flux.Tracker.back!(mean($p, 2), ones(2, 1))
BenchmarkTools.Trial:
  memory estimate:  1008 bytes
  allocs estimate:  21
  --------------
  minimum time:     5.973 μs (0.00% GC)
  median time:      6.310 μs (0.00% GC)
  mean time:        6.630 μs (1.96% GC)
  maximum time:     680.709 μs (97.28% GC)
  --------------
  samples:          10000
  evals/sample:     6
```
55 lines · 1.4 KiB · Julia
using Flux.Tracker, Base.Test, NNlib
using Flux.Tracker: gradcheck
# Numerically check the gradient of `f` at the given input arrays.
# `f` may return an array, so its output is reduced with `sum` to obtain the
# scalar-valued function that is handed to `gradcheck`.
function gradtest(f, xs::AbstractArray...)
  scalar_loss(args...) = sum(f(args...))
  return gradcheck(scalar_loss, xs...)
end

# Convenience form: every argument after `f` is passed to `rand` (call sites
# use an integer length or a tuple of dimensions), and the resulting random
# inputs are checked with the array method above.
function gradtest(f, dims...)
  inputs = rand.(dims)
  return gradtest(f, inputs...)
end
@testset "Tracker" begin

# Dense-layer style expression, for a vector input and for a (5, 3) matrix input.
@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)

# Sum over two dimensions at once, followed by a broadcast.
@test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))

# softmax on a vector and on a matrix.
@test gradtest(x -> softmax(x).*(1:3), 3)
@test gradtest(x -> softmax(x).*(1:3), (3,5))

# Loss functions.
@test gradtest(Flux.mse, rand(5,5), rand(5, 5))
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))

# Transpose of a vector.
@test gradtest(x -> x', rand(5))

# Concatenation of vectors and of matrices.
@test gradtest(vcat, rand(5), rand(3))
@test gradtest(vcat, rand(2,3), rand(3,3))

@testset "mean" begin
  # Whole-array mean.
  @test gradtest(mean, rand(2, 3))

  # Mean along a single dimension.
  @test gradtest(x -> mean(x, 1), rand(2, 3))
  @test gradtest(x -> mean(x, 2), rand(2, 3))
  @test gradtest(x -> mean(x, 3), rand(2, 3, 4))

  # Mean along several dimensions given as a vector.
  @test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4))
end

# A multi-statement function, passed to gradtest via do-block syntax.
@test gradtest(rand(5)) do x
  y = x.^2
  2y + x
end

# `param` keeps a floating-point element type; a scalar becomes a
# 0-dimensional TrackedArray. One nested testset per T, so a failure
# reports which type broke.
@testset "param eltype: $T" for T in [Float32, Float64]
  @test isa(param(T(1)), TrackedArray{T, 0})
  @test isa(param(rand(T, 2)), TrackedArray{T, 1})
  @test isa(param(rand(T, 2,2)), TrackedArray{T, 2})
end

# TODO: do we want this behaviour? Integer inputs to `param` are converted
# to the float type `F` instead of keeping their integer element type.
F = typeof(AbstractFloat(1))  # the type AbstractFloat converts an Int to (presumably Float64)
@testset "param eltype: $T" for T in [Int32, Int64]
  @test isa(param(T(1)), TrackedArray{F, 0})
  @test isa(param(rand(T, 2)), TrackedArray{F, 1})
  @test isa(param(rand(T, 2,2)), TrackedArray{F, 2})
end

end # testset