Update testsets

This commit is contained in:
Johnny Chen 2018-08-25 16:30:46 +08:00
parent 4baf85bbe2
commit b35664c59f

View File

@ -3,7 +3,7 @@ using Test, Random
@testset "basic" begin @testset "basic" begin
@testset "Chain" begin @testset "Chain" begin
@test_nowarn Chain(Dense(10, 5, σ),Dense(5, 2))(randn(10)) @test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10))
@test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn(10)) @test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn(10))
# numeric test should be put into testset of corresponding layer # numeric test should be put into testset of corresponding layer
end end
@ -14,11 +14,10 @@ using Test, Random
@test_throws MethodError Dense(10, 5)(1) # avoid broadcasting @test_throws MethodError Dense(10, 5)(1) # avoid broadcasting
@test_throws MethodError Dense(10, 5).(randn(10)) # avoid broadcasting @test_throws MethodError Dense(10, 5).(randn(10)) # avoid broadcasting
Random.seed!(0) @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,1)) == [10]
@test all(Dense(10, 1)(randn(10)).data .≈ 1.1774348382231168) @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == [10 10]
Random.seed!(0) @test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == [10; 10]
@test all(Dense(10, 2)(randn(10)).data .≈ [ -0.3624741476779616 @test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
-0.46724765394534323])
end end
@ -27,8 +26,9 @@ using Test, Random
@test length(Flux.Diagonal(10)(1)) == 10 @test length(Flux.Diagonal(10)(1)) == 10
@test length(Flux.Diagonal(10)(randn(1))) == 10 @test length(Flux.Diagonal(10)(randn(1))) == 10
@test_throws DimensionMismatch Flux.Diagonal(10)(randn(2)) @test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))
Random.seed!(0)
@test all(Flux.Diagonal(2)(randn(2)).data .≈ [ 0.6791074260357777, @test Flux.Diagonal(2)([1 2]) == [1 2; 1 2]
0.8284134829000359]) @test Flux.Diagonal(2)([1,2]) == [1,2]
@test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4]
end end
end end