Flux.jl/test/tracker.jl

45 lines
1.2 KiB
Julia
Raw Normal View History

2017-08-23 01:03:17 +00:00
using Flux.Tracker, Base.Test, NNlib
2017-08-23 00:43:45 +00:00
using Flux.Tracker: gradcheck
# Check the tracked gradient of `sum(f(xs...))` with respect to every input in
# `xs` using `Tracker.gradcheck` (presumably against a numerical estimate).
# Summing reduces any array-valued `f` to the scalar that `gradcheck` needs.
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(f(xs...)), xs...)

# Convenience form: build random inputs from size specs, where each spec is an
# integer length (e.g. `5`) or a tuple of dimensions (e.g. `(2,5)`).
# `rand(d...)` is used rather than `rand.(dims)`: `rand((2,5))` means "a 2×5
# array" only on Julia 0.6, while from 0.7 on it samples an element of the
# tuple. `rand(d...)` builds the intended array on every version.
gradtest(f, dims...) = gradtest(f, map(d -> rand(d...), dims)...)
# Gradient-correctness and construction tests for Flux.Tracker.
# Each `gradtest` call checks tracked gradients via `Tracker.gradcheck`.
@testset "Tracker" begin

  # Dense-layer style expressions with a vector input and a matrix input.
  @test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
  @test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)

  # Reduction over dimensions 2 and 3, followed by an elementwise op.
  @test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))

  # softmax on a vector and on a matrix. Weighting by 1:3 keeps the summed
  # output from being constant (softmax outputs sum to one), which would
  # otherwise make the gradient trivially zero.
  @test gradtest(x -> softmax(x).*(1:3), 3)
  @test gradtest(x -> softmax(x).*(1:3), (3,5))

  # Loss functions, given explicit array arguments.
  @test gradtest(Flux.mse, rand(5,5), rand(5, 5))
  @test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))

  # Transpose and vertical concatenation.
  @test gradtest(x -> x', rand(5))
  @test gradtest(vcat, rand(5), rand(3))
  @test gradtest(vcat, rand(2,3), rand(3,3))

  # A multi-statement function, passed with do-block syntax.
  @test gradtest(rand(5)) do x
    y = x.^2
    2y + x
  end

  # `param` keeps a floating-point element type and returns a TrackedArray
  # whose dimensionality matches the input (a scalar becomes 0-dimensional).
  for T in [Float32, Float64]
    @test isa(param(T(1)), TrackedArray{T, 0})
    @test isa(param(rand(T, 2)), TrackedArray{T, 1})
    @test isa(param(rand(T, 2,2)), TrackedArray{T, 2})
  end

  # Integer inputs are expected to be promoted to the float type produced by
  # `AbstractFloat(1)` (Float64 for an Int literal).
  # TODO: do we want this behaviour??
  F = typeof(AbstractFloat(1))
  for T in [Int32, Int64]
    @test isa(param(T(1)), TrackedArray{F, 0})
    @test isa(param(rand(T, 2)), TrackedArray{F, 1})
    @test isa(param(rand(T, 2,2)), TrackedArray{F, 2})
  end

end #testset