Flux.jl/test/tracker.jl

16 lines
436 B
Julia
Raw Normal View History

2019-03-07 01:33:02 +00:00
using Flux, Test
using Tracker: gradcheck
2017-08-23 00:43:45 +00:00
2018-02-05 18:10:02 +00:00
# gradtest(f, xs::AbstractArray...): run `gradcheck` on `f` at the given array
# inputs. `f`'s output is reduced to a scalar via `sum(sin.(...))` so that
# `gradcheck` receives a scalar-valued function of the same inputs.
function gradtest(f, xs::AbstractArray...)
    scalar_loss(args...) = sum(sin.(f(args...)))
    return gradcheck(scalar_loss, xs...)
end
2018-08-11 09:51:07 +00:00
# gradtest(f, dims...): convenience form that takes array sizes instead of
# arrays. Each entry `d` of `dims` becomes `rand(Float64, d)` (e.g. `(5, 5)`
# gives a 5×5 matrix), and the resulting arrays are forwarded to the
# AbstractArray method of `gradtest`.
function gradtest(f, dims...)
    inputs = map(d -> rand(Float64, d), dims)
    return gradtest(f, inputs...)
end
2017-08-23 00:43:45 +00:00
2019-03-07 01:33:02 +00:00
@testset "Tracker" begin
    # Loss functions, each checked against two random 5×5 Float64 inputs.
    @test gradtest(Flux.mse, (5, 5), (5, 5))
    @test gradtest(Flux.crossentropy, (5, 5), (5, 5))

    # normalise with default keyword arguments, then with dims = 2.
    @test gradtest(Flux.normalise, (4, 3))
    @test gradtest(x -> Flux.normalise(x; dims = 2), (3, 4))
end