2018-08-11 12:54:59 +00:00
|
|
|
|
using Flux
|
2019-11-20 04:20:42 +00:00
|
|
|
|
using Flux: throttle, nfan, glorot_uniform, glorot_normal, stack, unstack
|
2019-11-19 07:50:40 +00:00
|
|
|
|
using StatsBase: var
|
2018-08-11 09:51:07 +00:00
|
|
|
|
using Random
|
2018-08-11 12:54:59 +00:00
|
|
|
|
using Test
|
2017-07-26 01:57:20 +00:00
|
|
|
|
|
2017-08-19 19:20:20 +00:00
|
|
|
|
@testset "Throttle" begin
    # NOTE(review): every case here depends on wall-clock timing (1 s timeout,
    # 10 ms sleep margin), so a heavily loaded machine could make them flaky.

    @testset "default behaviour" begin
        a = Float64[]
        # No keywords on purpose: this testset must exercise throttle's *defaults*.
        # The original spelled out leading=true, trailing=false here, which
        # duplicated the "leading behaviour" case instead of testing the defaults.
        f = throttle(() -> push!(a, time()), 1)
        f()
        f()
        f()
        sleep(1.01)
        # Only the first (leading) call runs; nothing fires after the window closes.
        @test length(a) == 1
    end

    @testset "leading behaviour" begin
        a = Float64[]
        f = throttle(() -> push!(a, time()), 1; leading=true, trailing=false)
        f()
        @test length(a) == 1
        # A second call inside the 1 s window is suppressed.
        f()
        @test length(a) == 1
        sleep(1.01)
        # Once the window has elapsed, the next call runs immediately.
        f()
        @test length(a) == 2
    end

    @testset "trailing behaviour" begin
        a = Float64[]
        f = throttle(() -> push!(a, time()), 1; leading=false, trailing=true)
        # With leading=false, no call runs immediately.
        f()
        @test length(a) == 0
        f()
        @test length(a) == 0
        sleep(1.01)
        # A single trailing call runs after the window.
        @test length(a) == 1
    end

    @testset "arguments" begin
        a = Int[]
        f = throttle(x -> push!(a, x), 1; leading=true, trailing=true)
        f(1)
        @test a == [1]
        f(2)
        @test a == [1]
        f(3)
        @test a == [1]
        sleep(1.01)
        # The trailing call uses the arguments of the most recent suppressed call.
        @test a == [1, 3]
    end
end
|
2017-12-08 13:46:12 +00:00
|
|
|
|
|
2017-12-05 07:47:03 +00:00
|
|
|
|
@testset "Initialization" begin
    # Seed the global RNG so the statistical variance checks below are reproducible
    # and don't fail randomly.
    Random.seed!(0)

    @testset "Fan in/out" begin
        # Each entry pairs the dimensions handed to `nfan` with the expected
        # (fan_in, fan_out) tuple.
        fan_cases = [
            ((), (1, 1)),                                           # constant
            ((100,), (1, 100)),                                     # vector
            ((100, 200), (200, 100)),                               # Dense layer
            ((2, 30, 40), (2 * 30, 2 * 40)),                        # 1D Conv layer
            ((2, 3, 40, 50), (2 * 3 * 40, 2 * 3 * 50)),             # 2D Conv layer
            ((2, 3, 4, 50, 60), (2 * 3 * 4 * 50, 2 * 3 * 4 * 60)),  # 3D Conv layer
        ]
        for (shape, expected) in fan_cases
            @test nfan(shape...) == expected
        end
    end

    @testset "glorot" begin
        # Both glorot_uniform and glorot_normal should yield a kernel whose variance
        # is within ±10% of 2 / (fan_in + fan_out).
        shapes = [(1000,), (100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)]
        # `shape` is the outer loop and `initfn` the inner one, so draws from the
        # seeded RNG happen in a fixed order.
        for shape in shapes, initfn in (glorot_uniform, glorot_normal)
            w = initfn(shape...)
            target = 2 / sum(nfan(shape...))
            @test 0.9target < var(w) < 1.1target
        end
    end
end
|
2018-02-08 16:13:20 +00:00
|
|
|
|
|
|
|
|
|
@testset "Params" begin
    dense = Dense(10, 5)
    @test size.(params(dense)) == [(5, 10), (5,)]

    # Expected parameter shapes of RNN(10, 5), in the order `params` reports them.
    rnn_shapes = [(5, 10), (5, 5), (5,), (5,)]
    rnn = RNN(10, 5)
    @test size.(params(rnn)) == rnn_shapes

    # The same layer appearing twice in one Chain must contribute its params only once.
    @test size.(params(Chain(rnn, rnn))) == rnn_shapes

    # An array that contains itself: collecting params must terminate (no stack
    # overflow) and yield just the RNN's parameters.
    selfref = Any[nothing, rnn]
    selfref[1] = selfref
    @test size.(params(selfref)) == rnn_shapes
end
|
2018-12-20 15:03:21 +00:00
|
|
|
|
|
2019-02-15 14:50:01 +00:00
|
|
|
|
@testset "Basic Stacking" begin
    x = randn(3, 3)
    stacked = stack([x, x], 2)
    # Stacking two 3×3 matrices along dim 2 adds a new axis of length 2 there.
    @test size(stacked) == (3, 2, 3)
    # Check the contents too, not only the shape: each slice along the new axis
    # must be the corresponding input. A wrongly permuted result of the same size
    # would otherwise pass.
    @test stacked[:, 1, :] == x
    @test stacked[:, 2, :] == x
end
|
2019-02-15 14:50:01 +00:00
|
|
|
|
|
2019-01-25 10:06:37 +00:00
|
|
|
|
@testset "Precision" begin
    model = Chain(Dense(10, 5, relu), Dense(5, 2))
    input = rand(10)  # Float64 input on purpose

    # A freshly built model is expected to hold Float32 weights and produce Float32
    # output, even when fed Float64 data.
    @test eltype(model[1].W) == Float32
    @test eltype(model(input)) == Float32

    # f64 should convert the model to Float64 throughout; f32 should convert it back.
    model64 = f64(model)
    @test eltype(model64(input)) == Float64
    @test eltype(model64[1].W) == Float64
    @test eltype(f32(model64)[1].W) == Float32
end
|
2019-01-29 09:41:15 +00:00
|
|
|
|
|
|
|
|
|
@testset "Stacking" begin
    mat = [8 9 3 5; 9 6 6 9; 9 1 7 2; 7 4 10 6]
    # The columns of `mat`, i.e. what unstacking along dimension 2 should produce.
    columns = [[8, 9, 9, 7], [9, 6, 1, 4], [3, 6, 7, 10], [5, 9, 2, 6]]

    @test unstack(mat, 2) == columns
    @test stack(columns, 2) == mat

    # Splitting into rows and stacking them back along dim 1 must round-trip.
    @test stack(unstack(mat, 1), 1) == mat
end
|