Flux.jl/test/tracker.jl

255 lines
7.5 KiB
Julia
Raw Normal View History

2018-07-12 19:42:32 +00:00
using Flux
2017-08-23 01:03:17 +00:00
using Flux.Tracker, Base.Test, NNlib
2018-07-09 16:52:34 +00:00
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
2018-02-26 22:43:07 +00:00
using NNlib: conv
2018-07-18 06:41:10 +00:00
# using StatsBase
2017-08-23 00:43:45 +00:00
2018-02-05 18:10:02 +00:00
# Numerically check Tracker's gradients of `f` at the inputs `xs`.
# The outputs are reduced with `sum(sin.(...))`. A bare `sum` would send a constant
# upstream gradient of 1 to every output element; `sin` makes it vary per element.
function gradtest(f, xs::AbstractArray...)
    loss(args...) = sum(sin.(f(args...)))
    return gradcheck(loss, xs...)
end

# Size-spec form: each extra argument (an `Int` length or a dimension tuple) is first
# turned into a random array with `rand`, then checked as above.
gradtest(f, dims...) = gradtest(f, map(rand, dims)...)
@testset "Tracker" begin
# Dense-layer style activations `act.(W*x .+ b)`, checked on a single input vector
# (length 5) and on a batch of 3 input columns.
for act in (σ, logσ), xsize in (5, (5, 3))
    @test gradtest((x, W, b) -> act.(W*x .+ b), xsize, (2, 5), 2)
end
2017-08-23 00:43:45 +00:00
2018-07-18 06:41:10 +00:00
# Products with an adjoint (transposed) operand on either side.
@test gradtest((w, x) -> w'*x, randn(Float64, 10, 2), randn(Float64, 10))
@test gradtest((w, x) -> w*x', randn(Float64, 5, 5), randn(Float64, 5, 5))

# Reductions over all elements and over selected dimensions. Every dimension-wise
# reduction uses the `dims` keyword, like the `sum` cases; the positional
# `prod(x, region)` form is deprecated since Julia 0.7.
@test gradtest(x -> sum(x, dims = (2, 3)), (3, 4, 5))
@test gradtest(x -> sum(x, dims = 1), randn(Float64, 2, 3))
@test gradtest(x -> sum(x, dims = [1, 2]), randn(Float64, 2, 3))
@test gradtest(x -> sum(x), randn(Float64, 2, 3))
@test gradtest(x -> prod(x, dims = (2, 3)), (3, 4, 5))
@test gradtest(x -> prod(x), (3, 4, 5))
2017-08-23 00:43:45 +00:00
2017-09-03 21:10:35 +00:00
# softmax / logsoftmax on a single vector and on a batch of columns, with each
# output row weighted by 1:3.
for sm in (softmax, logsoftmax), xsize in (3, (3, 5))
    @test gradtest(x -> sm(x).*(1:3), xsize)
end

# Loss functions taking a (prediction, target) pair of equal shape.
for loss in (Flux.mse, Flux.crossentropy)
    @test gradtest(loss, rand(5, 5), rand(5, 5))
end

# Adjoint of a vector.
@test gradtest(x -> x', rand(5))
2017-08-23 16:50:43 +00:00
# Check that the concatenation function `f` returns a `TrackedArray` whenever at least
# one argument is tracked, and that all tracked results agree with the untracked one.
function promotiontest(f, A, B, C)
    r0 = f(A, B, C)
    r1 = f(param(A), B, C)
    r2 = f(A, param(B), C)
    if all(ndims.((A, B, C)) .≤ 2) && f ∈ [hcat, vcat]
        r3 = f(A, B, param(C))
    else
        # For other cases, tracking only the last argument is expected to throw a
        # MethodError; reuse r2 so the comparisons below still run.
        @test_throws MethodError f(A, B, param(C)) # until julia#20815 is resolved
        r3 = r2
    end
    r4 = f(param(A), param(B), param(C))

    @test !isa(r0, TrackedArray)
    @test all(isa.([r1, r2, r3, r4], TrackedArray))
    @test r1 == r2 == r3 == r4
    @test r0 == Flux.data(r4)
end
2018-05-02 13:54:40 +00:00
@testset "concat" begin
    # Fixed-dimension wrappers so `cat` runs through the same loops as vcat/hcat.
    cat1(x...) = cat(x..., dims = 1)
    cat2(x...) = cat(x..., dims = 2)

    @testset for vcatf in [vcat, cat1]
        @test gradtest(vcatf, rand(5), rand(3))
        @test gradtest(vcatf, rand(5), rand(3), rand(8))
        @test gradtest(vcatf, rand(5)', rand(5)')
        @test gradtest(vcatf, rand(5,2), rand(3,2), rand(8,2))
        @test gradtest(vcatf, rand(5,2,3), rand(3,2,3), rand(8,2,3))
        @test gradtest(vcatf, rand(5), rand(3,1))
        @test gradtest(vcatf, rand(5)', rand(2,5))
    end

    @testset for hcatf in [hcat, cat2]
        @test gradtest(hcatf, rand(5), rand(5))
        @test gradtest(hcatf, rand(5)', rand(5)')
        @test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8))
        @test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3))
        @test gradtest(hcatf, rand(5), rand(5), rand(5,2))
        @test gradtest(hcatf, rand(5)', rand(1,3))
        @test gradtest(hcatf, rand(5), rand(5,2))
    end

    # Single-argument concatenation along existing and new dimensions.
    @testset for catf in [vcat, cat1, hcat, cat2,
                          (x...) -> cat(x..., dims = 3),
                          (x...) -> cat(x..., dims = (1,2))]
        @test gradtest(catf, rand(5))
        @test gradtest(catf, rand(5)')
        @test gradtest(catf, rand(2,5))
        @test gradtest(catf, rand(2,5,3))
    end

    @test gradtest((x...) -> cat(x..., dims = 3), rand(2,5,2), rand(2,5,3), rand(2,5,4))

    # Concatenation along dimensions at or beyond the inputs' rank.
    @testset "cat($dim, ...)" for dim in 3:5
        catdim = (x...) -> cat(x..., dims = dim)
        @test gradtest(catdim, rand(5), rand(5), rand(5))
        @test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5))
        @test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3))
    end

    # Untracked inputs must produce untracked results. `dims` is passed as a keyword,
    # like every other `cat` call here (the positional form is deprecated since 0.7).
    @test !isa(vcat(rand(2)), TrackedArray)
    @test !isa(hcat(rand(2)), TrackedArray)
    @test !isa(cat(rand(2), dims = 1), TrackedArray)

    @test gradtest((a,b)->cat(a, b, dims = (2,3,5)), rand(2,3), rand(2,4,2,1))

    @testset "promotiontest" begin
        @testset for fcat in [hcat, vcat,
                              (x...) -> cat(x..., dims = 3),
                              (x...) -> cat(x..., dims = (1,2))]
            promotiontest(fcat, rand(2), rand(2), rand(2))
            promotiontest(fcat, rand(2)', rand(2)', rand(2)')
            promotiontest(fcat, rand(2,2), rand(2,2), rand(2,2))
            promotiontest(fcat, rand(2,2,2), rand(2,2,2), rand(2,2,2))
        end

        promotiontest(vcat, rand(1,2), rand(2)', rand(2,2))
        promotiontest(hcat, rand(2,1), rand(2), rand(2,2))
        promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5))
        promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5))
        promotiontest((x...) -> cat(x..., dims = 3), rand(4,5,3), rand(4,5,1), rand(4,5,2))
    end
end
2018-02-28 02:19:58 +00:00
# Axis permutation of a 3-d array.
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))

# `repeat` with inner/outer multipliers, on a vector and on a 3-d array.
@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
@test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))

# Kronecker products of two and three factors: first as vectors (no extra dims),
# then as single-column matrices, and finally as two-column matrices.
for extra in ((), (1,))
    @test gradtest(kron, rand(5, extra...), rand(3, extra...))
    @test gradtest(kron, rand(5, extra...), rand(3, extra...), rand(8, extra...))
end
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))

# Diagonal matrix built from a vector.
@test gradtest(diagm, rand(3))
@testset "mean" begin
    # Reduce over all elements, then over individual and combined dimensions.
    # Dimensions go through the `dims` keyword, like the `sum` tests; the positional
    # `mean(x, region)` form is deprecated since Julia 0.7.
    @test gradtest(mean, rand(2, 3))
    @test gradtest(x -> mean(x, dims = 1), rand(2, 3))
    @test gradtest(x -> mean(x, dims = 2), rand(2, 3))
    @test gradtest(x -> mean(x, dims = 3), rand(2, 3, 4))
    @test gradtest(x -> mean(x, dims = [1, 2]), rand(2, 3, 4))
end
@testset "maximum" begin
    # Reduce over all elements, then over individual and combined dimensions.
    # Dimensions go through the `dims` keyword, like the `sum` tests; the positional
    # `maximum(x, region)` form is deprecated since Julia 0.7.
    @test gradtest(maximum, rand(2, 3))
    @test gradtest(x -> maximum(x, dims = 1), rand(2, 3))
    @test gradtest(x -> maximum(x, dims = 2), rand(2, 3))
    @test gradtest(x -> maximum(x, dims = 3), rand(2, 3, 4))
    @test gradtest(x -> maximum(x, dims = [1, 2]), rand(2, 3, 4))
end
@testset "minimum" begin
    # Reduce over all elements, then over individual and combined dimensions.
    # Dimensions go through the `dims` keyword, like the `sum` tests; the positional
    # `minimum(x, region)` form is deprecated since Julia 0.7.
    @test gradtest(minimum, rand(2, 3))
    @test gradtest(x -> minimum(x, dims = 1), rand(2, 3))
    @test gradtest(x -> minimum(x, dims = 2), rand(2, 3))
    @test gradtest(x -> minimum(x, dims = 3), rand(2, 3, 4))
    @test gradtest(x -> minimum(x, dims = [1, 2]), rand(2, 3, 4))
end
2017-11-21 16:04:04 +00:00
# Standard deviation over all elements and along the first dimension. `dims` is a
# keyword, consistent with the other reductions; `std(x, region)` is deprecated since 0.7.
@test gradtest(x -> std(x), rand(5,5))
@test gradtest(x -> std(x, dims = 1), rand(5,5))

# Elementwise and inner products of two vectors.
@test gradtest((x, y) -> x .* y, rand(5), rand(5))
@test gradtest(dot, rand(5), rand(5))

# Vector norm.
@test gradtest(vecnorm, rand(5))
2017-09-07 03:09:32 +00:00
# A multi-step function passed via do-block syntax: f(x) = 2x.^2 + x.
@test gradtest(rand(5)) do x
    squared = x.^2
    return 2squared + x
end
2018-07-18 06:41:10 +00:00
# Convolution over 1-, 2- and 3-d spatial inputs with kernels of matching rank.
@test gradtest(conv, rand(10, 3, 2), randn(Float64, 2, 3, 2))
@test gradtest(conv, rand(10, 10, 3, 2), randn(Float64, 2, 2, 3, 2))
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64, 2, 2, 2, 3, 2))

# Pooling with 2-d and 3-d windows; each pool uses its own 5-d input size.
for (pool, vol) in ((maxpool, (10, 10, 10, 3, 2)), (meanpool, (5, 5, 5, 3, 2)))
    @test gradtest(x -> pool(x, (2,2)), rand(10, 10, 3, 2))
    @test gradtest(x -> pool(x, (2,2,2)), rand(vol))
end
2017-12-14 18:48:38 +00:00
2018-01-15 17:00:47 +00:00
# Elementwise comparison involving a tracked array.
@test (param([1, 2, 3]) .< 2) == [true, false, false]

# Literal integer power of a tracked scalar.
@test param(2)^2 == 4.0
2018-04-02 20:09:57 +00:00
@testset "reshape" begin
    # A reshaped tracked array must remain tracked and take the requested size.
    function check_tracked(y, sz)
        @test y isa TrackedArray
        @test size(y) == sz
    end

    check_tracked(reshape(param(rand(2,2,2)), 4, 2), (4,2))
    # Shapes containing a `:` placeholder, which is inferred from the length.
    check_tracked(reshape(param([1]), (1,:)), (1,1))
    check_tracked(reshape(param(rand(2)), (2,:)), (2,1))
    check_tracked(reshape(param(rand(2,2)), (1,:,2)), (1,2,2))
end
2018-02-12 12:31:15 +00:00
@testset "Intermediates" begin
    x = param([1])
    loss = sum((x .+ x).^2)   # (2x)^2 has gradient 8x, i.e. 8 at x = 1
    Flux.back!(loss)
    @test x.grad == [8]

    # Clear the leaf gradient and backpropagate through the same graph again:
    # the second pass is expected to yield the same gradient as the first.
    fill!(x.grad, 0)
    Flux.back!(loss)
    @test x.grad == [8]
end
2018-02-13 10:20:38 +00:00
@testset "Fallbacks" begin
    # `similar` on a tracked integer-literal matrix is expected to give a plain
    # Matrix{Float64}.
    @test similar(param([1 2; 3 4])) isa Matrix{Float64}
end
2018-02-12 15:05:09 +00:00
# Printf formatting of a tracked scalar produced by `sum`.
@test @sprintf("%.2f", sum(param([1,2,3]))) == "6.00"

# A convolution with a tracked input must infer a concrete return type.
@inferred NNlib.conv(param(rand(10,10,3,2)), randn(Float64,2,2,3,4))

# Backpropagating directly from a tracked scalar leaves a gradient of 1 on it.
leaf = param(rand())
Tracker.back!(leaf)
@test Tracker.grad(leaf) == 1
2018-06-06 16:01:28 +00:00
@testset "collect" begin
    # `Tracker.collect` gathers tracked scalars into one tracked array whose elements
    # still backpropagate to the original scalars.
    u, v = param(2), param(3)
    uv = Tracker.collect([u, v])
    @test uv isa TrackedArray{Float64}
    back!(uv[1]*uv[2])
    # d(u*v)/du = v = 3 and d(u*v)/dv = u = 2
    @test grad.((u, v)) == (3, 2)
end
2018-07-09 12:39:10 +00:00
# Gradient Hooks
@testset "Hooks" begin
    # `Tracker.hook(-, x)` is expected to negate the gradient flowing back into `x`:
    # seeding the hooked value with 1 should leave -1 on `x`.
    x = param(2)
    back!(Tracker.hook(-, x))
    @test grad(x) == -1
end
2018-07-09 16:52:34 +00:00
@testset "Checkpointing" begin
    # Count how many times the wrapped multiplication is evaluated
    # (named `ncalls` to avoid shadowing `Base.count`).
    ncalls = 0
    function mul(a, b)
        ncalls += 1
        return a * b
    end

    # A plain derivative evaluates `mul` exactly once.
    @test derivative(x -> mul(5, x), 3) == 5
    @test ncalls == 1

    # Under `checkpoint`, the test expects two more evaluations of `mul`,
    # presumably one forward and one recomputation during the backward pass.
    @test derivative(x -> checkpoint(mul, 5, x), 3) == 5
    @test ncalls == 3
end
2017-10-12 08:56:23 +00:00
end #testset