2017-08-23 01:03:17 +00:00
|
|
|
|
using Flux.Tracker, Base.Test, NNlib
|
2018-02-13 10:20:38 +00:00
|
|
|
|
using Flux.Tracker: TrackedReal, gradcheck
|
2018-02-26 22:43:07 +00:00
|
|
|
|
using NNlib: conv
|
2017-08-23 00:43:45 +00:00
|
|
|
|
|
2018-02-05 18:10:02 +00:00
|
|
|
|
# Run Flux.Tracker's `gradcheck` on `f` at the arrays `xs`.
# The output of `f` is collapsed to a scalar as `sum(sin.(f(xs...)))` before checking;
# presumably the elementwise `sin` makes each output element receive a different
# upstream gradient rather than a uniform one (NOTE(review): confirm intent).
# Returns whatever `gradcheck` returns; every call site in this file wraps it in `@test`.
function gradtest(f, xs::AbstractArray...)
    scalar_loss(args...) = sum(sin.(f(args...)))
    return gradcheck(scalar_loss, xs...)
end
|
2017-08-23 00:43:45 +00:00
|
|
|
|
# Convenience method: create one `rand` array per entry of `dims` (the calls in this
# file pass an integer length or a tuple of dimensions) and forward the resulting
# arrays to `gradtest`. Calls whose extra arguments are all arrays dispatch to the
# more specific `AbstractArray...` method instead of this one.
function gradtest(f, dims...)
    inputs = map(rand, dims)
    return gradtest(f, inputs...)
end
|
|
|
|
|
|
|
|
|
|
@testset "Tracker" begin
    # Affine map followed by an elementwise activation, first with a vector input
    # and then with a matrix input.
    for xdims in (5, (5, 3))
        @test gradtest((x, W, b) -> σ.(W*x .+ b), xdims, (2, 5), 2)
    end
    for xdims in (5, (5, 3))
        @test gradtest((x, W, b) -> logσ.(W*x .+ b), xdims, (2, 5), 2)
    end

    # Matrix products involving an adjoint on either side.
    @test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
    @test gradtest((w, x) -> w*x', randn(5, 5), randn(5, 5))

    # Reductions over selected dimensions and over the whole array.
    @test gradtest(x -> sum(x, (2, 3)), (3, 4, 5))
    @test gradtest(x -> prod(x, (2, 3)), (3, 4, 5))
    @test gradtest(x -> prod(x), (3, 4, 5))

    # Softmax variants on a vector and on a matrix.
    for dims in (3, (3, 5))
        @test gradtest(x -> softmax(x) .* (1:3), dims)
    end
    for dims in (3, (3, 5))
        @test gradtest(x -> logsoftmax(x) .* (1:3), dims)
    end

    # Loss functions.
    @test gradtest(Flux.mse, rand(5, 5), rand(5, 5))
    @test gradtest(Flux.crossentropy, rand(5, 5), rand(5, 5))

    @test gradtest(x -> x', rand(5))

    # Concatenation: vcat, then cat along dimension 1 and dimension 2.
    @test gradtest(vcat, rand(5), rand(3))
    @test gradtest(vcat, rand(5), rand(3), rand(8))
    @test gradtest(vcat, rand(5, 2), rand(3, 2), rand(8, 2))
    for parts in ((rand(5), rand(3)),
                  (rand(5), rand(8)),
                  (rand(5, 2), rand(3, 2), rand(8, 2)))
        @test gradtest((i...) -> cat(1, i...), parts...)
    end
    for parts in ((rand(5, 1), rand(5, 1)),
                  (rand(5, 1), rand(5, 4)),
                  (rand(5, 2), rand(5, 4), rand(5, 8)))
        @test gradtest((i...) -> cat(2, i...), parts...)
    end

    @test gradtest(x -> permutedims(x, [3, 1, 2]), rand(4, 5, 6))

    @test gradtest(x -> repmat(x, 5, 5), rand(4, 5))
    @test gradtest(x -> repmat(x, 5), rand(4, 5))

    # Kronecker products of two and three factors (vectors and matrices).
    for factors in ((rand(5), rand(3)),
                    (rand(5), rand(3), rand(8)),
                    (rand(5, 1), rand(3, 1)),
                    (rand(5, 1), rand(3, 1), rand(8, 1)),
                    (rand(5, 2), rand(3, 2), rand(8, 2)))
        @test gradtest(kron, factors...)
    end

    @test gradtest(diagm, rand(3))

    # The same whole-array and per-region checks for each reduction, one nested
    # testset per function.
    for (name, op) in (("mean", mean), ("maximum", maximum), ("minimum", minimum))
        @testset "$name" begin
            @test gradtest(op, rand(2, 3))
            for region in (1, 2)
                @test gradtest(x -> op(x, region), rand(2, 3))
            end
            @test gradtest(x -> op(x, 3), rand(2, 3, 4))
            @test gradtest(x -> op(x, [1, 2]), rand(2, 3, 4))
        end
    end

    @test gradtest(x -> std(x), rand(5, 5))
    @test gradtest(x -> std(x, 1), rand(5, 5))

    @test gradtest((x, y) -> x .* y, rand(5), rand(5))
    @test gradtest(dot, rand(5), rand(5))

    @test gradtest(vecnorm, rand(5))

    # A multi-statement function supplied through a `do` block.
    @test gradtest(rand(5)) do x
        sq = x.^2
        return 2sq + x
    end

    # Convolution on 3-, 4- and 5-dimensional inputs with matching kernels.
    for (img, kernel) in ((rand(10, 3, 2), randn(2, 3, 2)),
                          (rand(10, 10, 3, 2), randn(2, 2, 3, 2)),
                          (rand(10, 10, 10, 3, 2), randn(2, 2, 2, 3, 2)))
        @test gradtest(conv, img, kernel)
    end

    # Pooling with 2- and 3-element windows.
    @test gradtest(x -> maxpool(x, (2, 2)), rand(10, 10, 3, 2))
    @test gradtest(x -> maxpool(x, (2, 2, 2)), rand(10, 10, 10, 3, 2))
    @test gradtest(x -> meanpool(x, (2, 2)), rand(10, 10, 3, 2))
    @test gradtest(x -> meanpool(x, (2, 2, 2)), rand(5, 5, 5, 3, 2))

    # Broadcast comparison and scalar power on tracked values.
    @test (param([1, 2, 3]) .< 2) == [true, false, false]
    @test param(2)^2 == 4.0

    @testset "reshape" begin
        x = reshape(param(rand(2, 2, 2)), 4, 2)
        @test x isa TrackedArray
        @test size(x) == (4, 2)
        # Tuple shapes containing a Colon; each row is (input, shape, expected size).
        for (src, shape, expected) in ((param([1]), (1, :), (1, 1)),
                                       (param(rand(2)), (2, :), (2, 1)),
                                       (param(rand(2, 2)), (1, :, 2), (1, 2, 2)))
            r = reshape(src, shape)
            @test r isa TrackedArray
            @test size(r) == expected
        end
    end

    @testset "Intermediates" begin
        # d/dp sum((p + p)^2) = 8p, i.e. 8 at p = 1. `back!` is called a second time
        # on the same result after zeroing the gradient, and must give 8 again.
        p = param([1])
        total = sum((p .+ p).^2)
        Flux.back!(total)
        @test p.grad == [8]
        p.grad .= 0
        Flux.back!(total)
        @test p.grad == [8]
    end

    @testset "Fallbacks" begin
        A = param([1 2; 3 4])
        @test similar(A) isa Matrix{Float64}
        # Remove this test if we do LowerTriangular properly
        lower = LowerTriangular(A)
        @test lower*lower' isa Matrix{TrackedReal{Float64}}
    end

    @test @sprintf("%.2f", sum(param([1,2,3]))) == "6.00"

    # Left outside `@test` on purpose: `@inferred` itself raises if inference fails.
    @inferred NNlib.conv(param(rand(10, 10, 3, 2)), randn(2, 2, 3, 4))

    # Calling back! directly on a scalar parameter is expected to leave gradient 1.
    leaf = param(rand())
    Tracker.back!(leaf)
    @test Tracker.grad(leaf) == 1
end #testset
|