Flux.jl/test/tracker.jl

18 lines
475 B
Julia
Raw Normal View History

2017-08-23 01:03:17 +00:00
using Flux.Tracker, Base.Test, NNlib
2017-08-23 00:43:45 +00:00
using Flux.Tracker: gradcheck
# Gradient-check helper. `f` may return an array, so its output is reduced to
# a scalar with `sum` before being handed to `gradcheck` along with the
# concrete input arrays `xs`.
function gradtest(f, xs::AbstractArray...)
    scalar_f = (args...) -> sum(f(args...))
    return gradcheck(scalar_f, xs...)
end

# Size-based convenience method: each entry of `dims` (an integer length or a
# tuple of dimensions) is turned into a random array with `rand`, and the
# array-based method above is invoked on the resulting inputs.
function gradtest(f, dims...)
    inputs = map(rand, dims)
    return gradtest(f, inputs...)
end
# Gradient checks for Flux.Tracker. Every case goes through `gradtest`, which
# compares Tracker's gradient of `sum(f(xs...))` with `gradcheck`'s result;
# each result must be wrapped in `@test`, otherwise a failing check is ignored.
@testset "Tracker" begin
    # Sigmoid of an affine map, with a vector input and a matrix input for `x`.
    @test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
    @test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
    # Reduction over dimensions 2 and 3 followed by an elementwise `sin`.
    @test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))
    # softmax on a vector and on a matrix. NOTE(review): the `.*(1:3)`
    # weighting is presumably there because the plain sum of softmax outputs
    # is constant, which would make the checked gradient trivially zero.
    @test gradtest(x -> softmax(x).*(1:3), 3)
    @test gradtest(x -> softmax(x).*(1:3), (3,5))
end