2018-07-12 19:42:32 +00:00
|
|
|
|
using Flux
|
2018-07-18 13:39:20 +00:00
|
|
|
|
using Flux.Tracker, Test, NNlib
|
2019-01-25 16:13:34 +00:00
|
|
|
|
using Flux.Tracker: TrackedReal, gradient, gradcheck, grad, checkpoint, forwarddiff
|
2018-09-08 19:44:06 +00:00
|
|
|
|
using NNlib: conv, ∇conv_data, depthwiseconv
|
2018-07-18 13:39:20 +00:00
|
|
|
|
using Printf: @sprintf
|
2019-02-08 07:55:33 +00:00
|
|
|
|
using LinearAlgebra: diagm, dot, LowerTriangular, norm, det, logdet, logabsdet
|
2018-07-19 08:58:43 +00:00
|
|
|
|
using Statistics: mean, std
|
2018-08-11 09:51:07 +00:00
|
|
|
|
using Random
|
2018-07-18 06:41:10 +00:00
|
|
|
|
# using StatsBase
|
2017-08-23 00:43:45 +00:00
|
|
|
|
|
2018-02-05 18:10:02 +00:00
|
|
|
|
"""
    gradtest(f, xs::AbstractArray...)
    gradtest(f, dims...)

Run `gradcheck` on the scalar loss `sum(sin.(f(xs...)))` and return its result,
which callers use directly as a `@test` condition.

The first method takes the input arrays `xs` explicitly. The second method draws one
input per entry `d` of `dims` via `rand(Float64, d)` (so `d` may be a length such as
`5` or a size tuple such as `(2, 5)`) and forwards those arrays to the first method.
"""
function gradtest(f, xs::AbstractArray...)
    # Passing f's output through `sin` before summing means the backward pass is
    # seeded with cos.(f(xs...)) instead of a constant all-ones sensitivity.
    loss(args...) = sum(sin.(f(args...)))
    return gradcheck(loss, xs...)
end

function gradtest(f, dims...)
    # One Float64 array per requested size, generated in argument order.
    inputs = map(d -> rand(Float64, d), dims)
    return gradtest(f, inputs...)
end
|
2017-08-23 00:43:45 +00:00
|
|
|
|
# Gradient tests for Flux.Tracker. Most checks go through `gradtest`, which runs
# `gradcheck` on (random) inputs; the remaining testsets cover concatenation type
# promotion, equality/ordering on tracked values, reshaping, repeated
# back-propagation, gradient hooks, checkpointing, parameter updates, implicit
# `Params` gradients, forward-mode Jacobians and custom sensitivities.
@testset "Tracker" begin

    # Every input below is drawn at random; pinning the global RNG makes any
    # gradient-check failure reproducible from run to run.
    Random.seed!(0)

    # Dense-layer style expressions on a vector input and a 5×3 batch.
    @test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
    @test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
    @test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
    @test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)

    # Matrix products involving adjoints.
    @test gradtest((w, x) -> w'*x, randn(Float64,10, 2), randn(Float64,10))
    @test gradtest((w, x) -> w*x', randn(Float64,5,5), randn(Float64,5,5))

    # Sum / product reductions, whole-array and along dimensions.
    @test gradtest(x -> sum(x, dims = (2, 3)), (3,4,5))
    @test gradtest(x -> sum(x, dims = 1), randn(Float64,2,3))
    @test gradtest(x -> sum(x, dims = [1,2]), randn(Float64,2,3))
    @test gradtest(x -> sum(x), randn(Float64,2,3))
    @test gradtest(x -> prod(x, dims=(2, 3)), (3,4,5))
    @test gradtest(x -> prod(x), (3,4,5))

    # Softmax and log-softmax, with each of the 3 output rows scaled by 1:3.
    @test gradtest(x -> softmax(x).*(1:3), 3)
    @test gradtest(x -> softmax(x).*(1:3), (3,5))
    @test gradtest(x -> logsoftmax(x).*(1:3), 3)
    @test gradtest(x -> logsoftmax(x).*(1:3), (3,5))

    # Loss functions.
    @test gradtest(Flux.mse, rand(5,5), rand(5, 5))
    @test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))

    @test gradtest(x -> x', rand(5))

    # Determinants. logdet needs a positive determinant, so it is tested on the
    # Gram matrix A*A' (positive definite whenever A is nonsingular).
    @test gradtest(det, (4, 4))
    let A = rand(4, 4)
        @test gradtest(logdet, A * A')
    end
    @test gradtest((x) -> logabsdet(x)[1], (4, 4))

    @testset "indexing & slicing" begin
        # Wrapped in `@test` so the gradient check is actually asserted rather
        # than computed and discarded.
        @test gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
    end

    # Check that `f` yields a TrackedArray exactly when at least one of its three
    # arguments is tracked, and that tracking does not change the computed values.
    function promotiontest(f, A, B, C)
        r0 = f(A, B, C)
        r1 = f(param(A), B, C)
        r2 = f(A, param(B), C)
        r3 = f(A, B, param(C))
        r4 = f(param(A), param(B), param(C))

        @test !isa(r0, TrackedArray)
        @test all(isa.([r1,r2,r3,r4], TrackedArray))
        @test r1 == r2 == r3 == r4
        @test r0 == Flux.data(r4)
    end

    @testset "concat" begin
        cat1(x...) = cat(x..., dims = 1)
        cat2(x...) = cat(x..., dims = 2)

        @testset for vcatf in [vcat, cat1]
            @test gradtest(vcatf, rand(5), rand(3))
            @test gradtest(vcatf, rand(5), rand(3), rand(8))
            @test gradtest(vcatf, rand(5)', rand(5)')
            @test gradtest(vcatf, rand(5,2), rand(3,2), rand(8,2))
            @test gradtest(vcatf, rand(5,2,3), rand(3,2,3), rand(8,2,3))
            @test gradtest(vcatf, rand(5), rand(3,1))
            @test gradtest(vcatf, rand(5)', rand(2,5))
        end

        @testset for hcatf in [hcat, cat2]
            @test gradtest(hcatf, rand(5), rand(5))
            @test gradtest(hcatf, rand(5)', rand(5)')
            @test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8))
            @test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3))
            @test gradtest(hcatf, rand(5), rand(5), rand(5,2))
            @test gradtest(hcatf, rand(5)', rand(1,3))
            @test gradtest(hcatf, rand(5), rand(5,2))
        end

        # Single-argument concatenation along various dimensions.
        @testset for catf in [vcat, cat1, hcat, cat2,
                              (x...) -> cat(x..., dims = 3),
                              (x...) -> cat(x..., dims = (1,2))]
            @test gradtest(catf, rand(5))
            @test gradtest(catf, rand(5)')
            @test gradtest(catf, rand(2,5))
            @test gradtest(catf, rand(2,5,3))
        end

        @test gradtest((x...) -> cat(x..., dims = 3), rand(2,5,2), rand(2,5,3), rand(2,5,4))

        # Concatenating along a dimension beyond the inputs' rank.
        @testset "cat($dim, ...)" for dim in 3:5
            catdim = (x...) -> cat(x..., dims = dim)
            @test gradtest(catdim, rand(5), rand(5), rand(5))
            @test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5))
            @test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3))
        end

        # Untracked inputs must stay untracked.
        @test !isa(vcat(rand(2)), TrackedArray)
        @test !isa(hcat(rand(2)), TrackedArray)
        @test !isa(cat(rand(2), dims=1), TrackedArray)

        @test gradtest((a,b)->cat(a, b, dims = (2,3,5)), rand(2,3), rand(2,4,2,1))

        @testset "promotiontest" begin
            @testset for fcat in [hcat, vcat,
                                  (x...) -> cat(x..., dims = 3),
                                  (x...) -> cat(x..., dims = (1,2))]
                promotiontest(fcat, rand(2), rand(2), rand(2))
                promotiontest(fcat, rand(2)', rand(2)', rand(2)')
                promotiontest(fcat, rand(2,2), rand(2,2), rand(2,2))
                promotiontest(fcat, rand(2,2,2), rand(2,2,2), rand(2,2,2))
            end

            # Mixed shapes.
            promotiontest(vcat, rand(1,2), rand(2)', rand(2,2))
            promotiontest(hcat, rand(2,1), rand(2), rand(2,2))
            promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5))
            promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5))
            promotiontest((x...) -> cat(x..., dims = 3), rand(4,5,3), rand(4,5,1), rand(4,5,2))
        end

        # Mixing plain scalars with tracked values still produces a TrackedArray.
        @testset "scalars" begin
            @test vcat(param([1, 2, 3]), 1) isa TrackedArray
            @test vcat(1, param([1, 2, 3])) isa TrackedArray
            @test hcat(1, param([1 2 3;])) isa TrackedArray
            @test vcat(param(1), 2) isa TrackedArray
        end
    end

    # Dimension permutation and repetition.
    @test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
    @test gradtest(x -> PermutedDimsArray(x, [3,1,2]), rand(4,5,6))

    @test gradtest(x -> repeat(x; inner=2), rand(5))
    @test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
    @test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))

    # Kronecker products with two and three factors.
    @test gradtest(kron, rand(5), rand(3))
    @test gradtest(kron, rand(5), rand(3), rand(8))
    @test gradtest(kron, rand(5,1), rand(3,1))
    @test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
    @test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))

    @test gradtest(x -> diagm(0 => x), rand(3))

    # Inverse and right/left division.
    @test gradtest(W -> inv(log.(W * W)), (5,5))
    @test gradtest((A, B) -> A / B, (1,5), (5,5))
    @test gradtest((A, B) -> log.(A * A) / exp.(B * B), (5,5), (5,5))
    @test gradtest((A, B) -> log.(A * A) \ exp.(B * B), (5,5), (5,5))

    @testset "mean" begin
        @test gradtest(mean, rand(2, 3))

        @test gradtest(x -> mean(x, dims=1), rand(2, 3))
        @test gradtest(x -> mean(x, dims=2), rand(2, 3))
        @test gradtest(x -> mean(x, dims=3), rand(2, 3, 4))

        @test gradtest(x -> mean(x, dims=[1, 2]), rand(2, 3, 4))
    end

    @testset "maximum" begin
        @test gradtest(maximum, rand(2, 3))

        @test gradtest(x -> maximum(x, dims=1), rand(2, 3))
        @test gradtest(x -> maximum(x, dims=2), rand(2, 3))
        @test gradtest(x -> maximum(x, dims=3), rand(2, 3, 4))

        @test gradtest(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4))
    end

    @testset "minimum" begin
        @test gradtest(minimum, rand(2, 3))

        @test gradtest(x -> minimum(x, dims=1), rand(2, 3))
        @test gradtest(x -> minimum(x, dims=2), rand(2, 3))
        @test gradtest(x -> minimum(x, dims=3), rand(2, 3, 4))

        @test gradtest(x -> minimum(x, dims=[1, 2]), rand(2, 3, 4))
    end

    # Standard deviation and normalisation.
    @test gradtest(x -> std(x), rand(5,5))
    @test gradtest(x -> std(x, dims = 1), rand(5,5))
    @test gradtest(x -> std(x, dims = 1, corrected = false), rand(5,5))

    @test gradtest(x -> Flux.normalise(x), rand(4,3))
    @test gradtest(x -> Flux.normalise(x, dims = 2), rand(3,4))

    # Elementwise product, dot product and norm.
    @test gradtest((x, y) -> x .* y, rand(5), rand(5))
    @test gradtest(dot, rand(5), rand(5))

    @test gradtest(norm, rand(5))

    # A small multi-statement function.
    @test gradtest(rand(5)) do x
        y = x.^2
        2y + x
    end

    # Convolutions in 1, 2 and 3 spatial dimensions.
    @test gradtest(conv, rand(10, 3, 2), randn(Float64, 2, 3, 2))
    @test gradtest(conv, rand(10, 10, 3, 2), randn(Float64, 2, 2, 3, 2))
    @test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64, 2, 2, 2, 3, 2))

    # The data gradient of conv, itself differentiated. (An identical second copy
    # of these three checks used to follow the depthwiseconv test; it was removed
    # as redundant.)
    @test gradtest(∇conv_data, rand(10, 3, 2), randn(Float64, 2, 2, 3))
    @test gradtest(∇conv_data, rand(10, 10, 3, 2), randn(Float64, 2, 2, 2, 3))
    @test gradtest(∇conv_data, rand(10, 10, 10, 3, 2), randn(Float64, 2, 2, 2, 2, 3))

    @test gradtest(depthwiseconv, rand(10,10,3,2), randn(2, 2, 2, 3))

    # Pooling in 2 and 3 spatial dimensions.
    @test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
    @test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))

    @test gradtest(x -> meanpool(x, (2,2)), rand(10, 10, 3, 2))
    @test gradtest(x -> meanpool(x, (2,2,2)), rand(5, 5, 5, 3, 2))

    # Broadcasted type conversion.
    @test gradtest(x -> Float64.(x), 5)

    @testset "equality & order" begin
        # TrackedReal
        @test param(2)^2 == param(4)
        @test param(2)^2 == 4
        @test 4 == param(2)^2

        @test param(2)^2 ≈ param(4)
        @test param(2)^2 ≈ 4
        @test 4 ≈ param(2)^2

        @test (param([1,2,3]) .< 2) == [true, false, false]
        @test (param([1,2,3]) .<= 2) == [true, true, false]
        @test (2 .> param([1,2,3])) == [true, false, false]
        @test (2 .>= param([1,2,3])) == [true, true, false]

        # TrackedArray
        @test param([1,2,3]).^2 == param([1,4,9])
        @test [1,2,3].^2 == param([1,4,9])
        @test param([1,2,3]).^2 == [1,4,9]

        @test param([1,2,3]).^2 ≈ param([1,4,9])
        @test [1,2,3].^2 ≈ param([1,4,9])
        @test param([1,2,3]).^2 ≈ [1,4,9]
    end

    # Reshaping a tracked array (including with a `:` placeholder) keeps it tracked.
    @testset "reshape" begin
        x = reshape(param(rand(2,2,2)), 4, 2)
        @test x isa TrackedArray
        @test size(x) == (4,2)
        x = reshape(param([1]), (1,:))
        @test x isa TrackedArray
        @test size(x) == (1,1)
        x = reshape(param(rand(2)), (2,:))
        @test x isa TrackedArray
        @test size(x) == (2,1)
        x = reshape(param(rand(2,2)), (1,:,2))
        @test x isa TrackedArray
        @test size(x) == (1,2,2)
    end

    # Back-propagating twice through the same loss (with `once = false`, presumably
    # so the recorded graph is retained) gives the same gradient each time once
    # `x.grad` is reset between the passes.
    @testset "Intermediates" begin
        x = param([1])
        l = sum((x .+ x).^2)
        Flux.back!(l, once = false)
        @test x.grad == [8]
        x.grad .= 0
        Flux.back!(l, once = false)
        @test x.grad == [8]
    end

    # `similar` on a tracked array falls back to a plain (untracked) array.
    @testset "Fallbacks" begin
        xs = param([1 2; 3 4])
        @test similar(xs) isa Matrix{Float64}
    end

    # Tracked reals can be formatted like ordinary numbers.
    @test @sprintf("%.2f", sum(param([1,2,3]))) == "6.00"

    # `@inferred` throws if conv on a tracked input is not type-stable.
    @inferred NNlib.conv(param(rand(10,10,3,2)),randn(Float64,2,2,3,4))

    # Back-propagating from a tracked scalar seeds its own gradient with 1.
    b = param(rand())
    Tracker.back!(b)
    @test Tracker.grad(b) == 1

    # `Tracker.collect` turns a vector of tracked reals into a TrackedArray that
    # gradients flow back through.
    @testset "collect" begin
        x, y = param(2), param(3)
        xy = Tracker.collect([x, y])
        @test xy isa TrackedArray{Float64}
        z = xy[1]*xy[2]
        back!(z)
        @test grad.((x,y)) == (3, 2)

        @test gradient(2, 3) do x, y
            xy = Tracker.collect([x, y])
            xy[1]*xy[2]
        end == (3, 2)
    end

    # Gradient Hooks: `hook(-, x)` keeps x's value but, per the expected result,
    # negates the gradient flowing back into `x`.
    @testset "Hooks" begin
        x = param(2)
        y = Tracker.hook(-, x)
        back!(y)
        @test grad(x) == -1
    end

    @testset "Checkpointing" begin
        # Counts calls to `mul`; named `ncalls` to avoid shadowing `Base.count`.
        ncalls = 0
        function mul(a, b)
            ncalls += 1
            a * b
        end
        @test gradient(x -> mul(5, x), 3)[1] == 5
        @test ncalls == 1
        # The expected total of 3 encodes that a checkpointed call runs `mul`
        # twice: once in the forward pass and again for the backward pass.
        @test gradient(x -> checkpoint(mul, 5, x), 3)[1] == 5
        @test ncalls == 3
    end

    # `Tracker.update!` adds the step to the parameter in place.
    @testset "Updates" begin
        xs = param([1, 2, 3])
        Tracker.update!(xs, param([4, 5, 6]))
        @test xs == [5, 7, 9]
        x = param(3)
        Tracker.update!(x, param(4))
        @test x == 7
    end

    # Implicit gradients over `Params` agree with explicit gradients.
    @testset "Params" begin
        W = param(randn(5, 10))
        x = rand(10)
        dW = gradient(W -> sum(W*x), W)[1]
        gs = gradient(() -> sum(W*x), Tracker.Params([W]))
        @test gs[W] == dW
    end

    @testset "Forward" begin
        @test @inferred(Tracker.forward_jacobian(x -> [sum(x)], rand(5,5), Val(12)))[2] ==
            reshape(ones(25), :, 1)
        @test gradient([2, 3]) do x
            forwarddiff(x) do x
                x[1]*x[2]
            end
        end == ([3, 2],)
    end

    # A non-unit output sensitivity: d/dx (3x^2 + 2x) at x = 5 is 6*5 + 2 = 32.
    @testset "Custom Sensitivities" begin
        y, back = Tracker.forward(x -> [3x^2, 2x], 5)
        @test back([1, 1]) == (32,)
    end

end #testset
|