2018-08-23 13:34:11 +00:00
|
|
|
|
using Test, Random
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@testset "basic" begin
    # Chain: composing layers must work when sizes line up and fail loudly
    # when they do not.
    @testset "Chain" begin
        # 10 -> 5 -> 2: consecutive layer sizes agree, so no warning is expected.
        @test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10))
        # The first layer produces 5 values but the second is built for 2 inputs.
        @test_throws DimensionMismatch Chain(Dense(10, 5, σ), Dense(2, 1))(randn(10))
        # Numeric tests belong in the testset of the corresponding layer.
    end

    @testset "Dense" begin
        # Shape checks.
        @test length(Dense(10, 5)(randn(10))) == 5
        @test_throws DimensionMismatch Dense(10, 5)(randn(1))
        # Scalar input, whether passed directly or via dot-broadcast over the
        # elements of a vector, must be rejected rather than silently broadcast.
        @test_throws DimensionMismatch Dense(10, 5)(1) # avoid broadcasting
        @test_throws DimensionMismatch Dense(10, 5).(randn(10)) # avoid broadcasting

        # Seeded numeric checks. The global RNG is reseeded before each one;
        # the layer is constructed before `randn(10)` is evaluated, and that
        # order must be kept for the pinned values to stay valid.
        # NOTE(review): the expected numbers presumably depend on the RNG
        # stream of the Julia version they were recorded with -- confirm when
        # upgrading.
        Random.seed!(0)
        @test all(Dense(10, 1)(randn(10)).data .≈ 1.1774348382231168)
        Random.seed!(0)
        # Comma-separated literal: same 2-element vector as the previous
        # newline-separated form, without relying on whitespace-sensitive
        # parsing of the leading minus signs.
        @test all(Dense(10, 2)(randn(10)).data .≈ [-0.3624741476779616,
                                                   -0.46724765394534323])
    end

    @testset "Diagonal" begin
        # Inputs of matching length, a scalar, and a 1-element vector all
        # yield an output of length 10.
        @test length(Flux.Diagonal(10)(randn(10))) == 10
        @test length(Flux.Diagonal(10)(1)) == 10
        @test length(Flux.Diagonal(10)(randn(1))) == 10
        # A length-2 input is incompatible with a size-10 layer.
        @test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))

        # Seeded numeric check against pinned values (see the RNG caveat that
        # applies to any hard-coded expectation derived from `Random.seed!`).
        Random.seed!(0)
        @test all(Flux.Diagonal(2)(randn(2)).data .≈ [0.6791074260357777,
                                                      0.8284134829000359])
    end
end
|