Require weight and bias to be AbstractArrays
This commit is contained in:
parent
e1f80d4627
commit
df84628c29
|
@ -102,7 +102,7 @@ julia> d(rand(5))
|
||||||
-0.16210233
|
-0.16210233
|
||||||
0.12311903```
|
0.12311903```
|
||||||
"""
|
"""
|
||||||
struct Dense{F,S,T}
|
struct Dense{F,S<:AbstractArray,T<:AbstractArray}
|
||||||
W::S
|
W::S
|
||||||
b::T
|
b::T
|
||||||
σ::F
|
σ::F
|
||||||
|
|
|
@ -28,6 +28,13 @@ import Flux: activations
|
||||||
end
|
end
|
||||||
|
|
||||||
@testset "Dense" begin
|
@testset "Dense" begin
|
||||||
|
@testset "constructors" begin
|
||||||
|
@test size(Dense(10, 100).W) == (100, 10)
|
||||||
|
@test Dense(rand(100,10), rand(10)).σ == identity
|
||||||
|
|
||||||
|
@test_throws MethodError Dense(10, 10.5)
|
||||||
|
end
|
||||||
|
|
||||||
@test length(Dense(10, 5)(randn(10))) == 5
|
@test length(Dense(10, 5)(randn(10))) == 5
|
||||||
@test_throws DimensionMismatch Dense(10, 5)(randn(1))
|
@test_throws DimensionMismatch Dense(10, 5)(randn(1))
|
||||||
@test_throws MethodError Dense(10, 5)(1) # avoid broadcasting
|
@test_throws MethodError Dense(10, 5)(1) # avoid broadcasting
|
||||||
|
@ -37,7 +44,6 @@ import Flux: activations
|
||||||
@test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == 10*ones(1, 2)
|
@test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == 10*ones(1, 2)
|
||||||
@test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(2, 1)
|
@test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(2, 1)
|
||||||
@test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
|
@test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
|
||||||
|
|
||||||
end
|
end
|
||||||
|
|
||||||
@testset "Diagonal" begin
|
@testset "Diagonal" begin
|
||||||
|
|
Loading…
Reference in New Issue