# Flux.jl/test/utils.jl
#
# 121 lines
# 3.1 KiB
# Julia
# Raw Permalink Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

using Flux
using Flux: throttle, nfan, glorot_uniform, glorot_normal, stack, unstack
using StatsBase: var
using Random
using Test
# Exercises `throttle` with a 1-second timeout under the different
# `leading` / `trailing` keyword combinations.
@testset "Throttle" begin
    @testset "default behaviour" begin
        calls = []
        throttled = throttle(1, leading=true, trailing=false) do
            push!(calls, time())
        end
        # Three rapid-fire invocations; only one should be recorded.
        for _ in 1:3
            throttled()
        end
        sleep(1.01)
        @test length(calls) == 1
    end

    @testset "leading behaviour" begin
        calls = []
        throttled = throttle(1, leading=true, trailing=false) do
            push!(calls, time())
        end
        # The first call is recorded immediately.
        throttled()
        @test length(calls) == 1
        # A second call before the sleep is not recorded.
        throttled()
        @test length(calls) == 1
        # After sleeping past the timeout, a call is recorded again.
        sleep(1.01)
        throttled()
        @test length(calls) == 2
    end

    @testset "trailing behaviour" begin
        calls = []
        throttled = throttle(1, leading=false, trailing=true) do
            push!(calls, time())
        end
        # With leading disabled, nothing is recorded up front.
        throttled()
        @test length(calls) == 0
        throttled()
        @test length(calls) == 0
        # Exactly one entry is expected after sleeping past the timeout.
        sleep(1.01)
        @test length(calls) == 1
    end

    @testset "arguments" begin
        seen = []
        throttled = throttle(1, leading=true, trailing=true) do x
            push!(seen, x)
        end
        # The leading call records its argument straight away.
        throttled(1)
        @test seen == [1]
        throttled(2)
        @test seen == [1]
        throttled(3)
        @test seen == [1]
        # After the sleep, only the last argument (3) has been added.
        sleep(1.01)
        @test seen == [1, 3]
    end
end
# Checks the fan-in/fan-out helper `nfan` and the sample variance produced by
# the Glorot initialisers.
@testset "Initialization" begin
    # Set random seed so that these tests don't fail randomly
    Random.seed!(0)

    @testset "Fan in/out" begin
        @test nfan() == (1, 1) #For a constant
        @test nfan(100) == (1, 100) #For vector
        @test nfan(100, 200) == (200, 100) #For Dense layer
        @test nfan(2, 30, 40) == (2 * 30, 2 * 40) #For 1D Conv layer
        @test nfan(2, 3, 40, 50) == (2 * 3 * 40, 2 * 3 * 50) #For 2D Conv layer
        @test nfan(2, 3, 4, 50, 60) == (2 * 3 * 4 * 50, 2 * 3 * 4 * 60) #For 3D Conv layer
    end

    @testset "glorot" begin
        # glorot_uniform and glorot_normal should both yield a kernel with
        # variance ≈ 2/(fan_in + fan_out)
        #
        # NOTE: the loop headers use ASCII `in`; the membership operator was
        # missing from them, which left the file unparseable.
        for dims in [(1000,), (100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)]
            for init in [glorot_uniform, glorot_normal]
                v = init(dims...)
                fan_in, fan_out = nfan(dims...)
                σ2 = 2 / (fan_in + fan_out)
                # Accept a sample variance within ±10% of the expected value.
                @test 0.9σ2 < var(v) < 1.1σ2
            end
        end
    end
end
# Verifies the sizes of the arrays that `params` returns for several containers.
@testset "Params" begin
    dense = Dense(10, 5)
    @test size.(params(dense)) == [(5, 10), (5,)]

    # Sizes expected from the RNN(10, 5) layer in every case below.
    rnn_shapes = [(5, 10), (5, 5), (5,), (5,)]

    rnn = RNN(10, 5)
    @test size.(params(rnn)) == rnn_shapes

    # Layer duplicated in same chain, params just once pls.
    doubled = Chain(rnn, rnn)
    @test size.(params(doubled)) == rnn_shapes

    # Self-referential array. Just want params, no stack overflow pls.
    cyclic = Any[nothing, rnn]
    cyclic[1] = cyclic
    @test size.(params(cyclic)) == rnn_shapes
end
# Stacking two 3×3 matrices along dimension 2 should give a 3×2×3 array.
@testset "Basic Stacking" begin
    mat = randn(3, 3)
    @test size(stack([mat, mat], 2)) == (3, 2, 3)
end
# Checks element types of weights and outputs before and after `f64` / `f32`.
@testset "Precision" begin
    model = Chain(Dense(10, 5, relu), Dense(5, 2))
    input = rand(10)

    # Element types of the unconverted model.
    @test eltype(model[1].W) == Float32
    @test eltype(model(input)) == Float32

    # Element types after converting with f64.
    @test eltype(f64(model)(input)) == Float64
    @test eltype(f64(model)[1].W) == Float64

    # Converting back with f32.
    @test eltype(f32(f64(model))[1].W) == Float32
end
# `stack` and `unstack` are checked against each other on a fixed 4×4 matrix.
@testset "Stacking" begin
    matrix = [
        8 9  3 5
        9 6  6 9
        9 1  7 2
        7 4 10 6
    ]
    # The columns of `matrix`, in order.
    columns = [[8, 9, 9, 7], [9, 6, 1, 4], [3, 6, 7, 10], [5, 9, 2, 6]]

    @test unstack(matrix, 2) == columns
    @test stack(columns, 2) == matrix
    # Unstacking then restacking along dimension 1 returns the original.
    @test stack(unstack(matrix, 1), 1) == matrix
end