Merge pull request #372 from johnnychen94/issue-#354
Type restriction for Dense layer
This commit is contained in:
commit
73385b5dbd
|
@ -75,7 +75,7 @@ end
|
|||
|
||||
@treelike Dense
|
||||
|
||||
function (a::Dense)(x)
|
||||
function (a::Dense)(x::AbstractArray)
|
||||
W, b, σ = a.W, a.b, a.σ
|
||||
σ.(W*x .+ b)
|
||||
end
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
using Test, Random
|
||||
|
||||
@testset "basic" begin
|
||||
@testset "Chain" begin
|
||||
@test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10))
|
||||
@test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn(10))
|
||||
# numeric test should be put into testset of corresponding layer
|
||||
end
|
||||
|
||||
@testset "Dense" begin
|
||||
@test length(Dense(10, 5)(randn(10))) == 5
|
||||
@test_throws DimensionMismatch Dense(10, 5)(randn(1))
|
||||
@test_throws MethodError Dense(10, 5)(1) # avoid broadcasting
|
||||
@test_throws MethodError Dense(10, 5).(randn(10)) # avoid broadcasting
|
||||
|
||||
@test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(1, 1)
|
||||
@test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == 10*ones(1, 2)
|
||||
@test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(2, 1)
|
||||
@test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
|
||||
|
||||
end
|
||||
|
||||
@testset "Diagonal" begin
|
||||
@test length(Flux.Diagonal(10)(randn(10))) == 10
|
||||
@test length(Flux.Diagonal(10)(1)) == 10
|
||||
@test length(Flux.Diagonal(10)(randn(1))) == 10
|
||||
@test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))
|
||||
|
||||
@test Flux.Diagonal(2)([1 2]) == [1 2; 1 2]
|
||||
@test Flux.Diagonal(2)([1,2]) == [1,2]
|
||||
@test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4]
|
||||
end
|
||||
end
|
|
@ -32,6 +32,7 @@ include("data.jl")
|
|||
|
||||
@info "Testing Layers"
|
||||
|
||||
include("layers/basic.jl")
|
||||
include("layers/normalisation.jl")
|
||||
include("layers/stateless.jl")
|
||||
include("layers/conv.jl")
|
||||
|
|
Loading…
Reference in New Issue