Require weight and bias to be AbstractArrays
This commit is contained in:
parent
e1f80d4627
commit
df84628c29
|
@ -102,7 +102,7 @@ julia> d(rand(5))
|
|||
-0.16210233
|
||||
0.12311903```
|
||||
"""
|
||||
struct Dense{F,S,T}
|
||||
struct Dense{F,S<:AbstractArray,T<:AbstractArray}
|
||||
W::S
|
||||
b::T
|
||||
σ::F
|
||||
|
|
|
@ -28,6 +28,13 @@ import Flux: activations
|
|||
end
|
||||
|
||||
@testset "Dense" begin
|
||||
@testset "constructors" begin
|
||||
@test size(Dense(10, 100).W) == (100, 10)
|
||||
@test Dense(rand(100,10), rand(10)).σ == identity
|
||||
|
||||
@test_throws MethodError Dense(10, 10.5)
|
||||
end
|
||||
|
||||
@test length(Dense(10, 5)(randn(10))) == 5
|
||||
@test_throws DimensionMismatch Dense(10, 5)(randn(1))
|
||||
@test_throws MethodError Dense(10, 5)(1) # avoid broadcasting
|
||||
|
@ -37,7 +44,6 @@ import Flux: activations
|
|||
@test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == 10*ones(1, 2)
|
||||
@test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(2, 1)
|
||||
@test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
|
||||
|
||||
end
|
||||
|
||||
@testset "Diagonal" begin
|
||||
|
|
Loading…
Reference in New Issue