Merge #1218
1218: Require weight and bias to be AbstractArrays r=CarloLucibello a=oxinabox closes #1199 While in theory someone could be using Dense with weights and biases that are not abstract arrays, I would be surprised. So allowing it is just leaving a food-gun laying around. If it is common then we can instead close #1199 by adding a special constructor for `Number` subtypes that error if they are not integers, or something a long those lines. ### PR Checklist - [x] Tests are added - [x] Entry in NEWS.md I think this is a bug-fix thus the following are not required: - [ ] Documentation, if applicable - [ ] Final review from `@MikeInnes` or `@dhairyagandhi96` (for API changes). Co-authored-by: Lyndon White <lyndon.white@invenialabs.co.uk> Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
This commit is contained in:
commit
97406507fd
1
NEWS.md
1
NEWS.md
|
@ -1,5 +1,6 @@
|
|||
# v0.11
|
||||
* Change to `DataLoader`'s constructor [https://github.com/FluxML/Flux.jl/pull/1152]
|
||||
* Error if Dense layers weights and biases are not arrays [https://github.com/FluxML/Flux.jl/pull/1218].
|
||||
|
||||
# v0.10.5
|
||||
* Add option for [same padding](https://github.com/FluxML/Flux.jl/pull/901) to conv and pooling layers by setting `pad=SamePad()`.
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
name = "Flux"
|
||||
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
|
||||
version = "0.11.0"
|
||||
version = "0.11.0-DEV"
|
||||
|
||||
[deps]
|
||||
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
|
||||
|
|
|
@ -102,7 +102,7 @@ julia> d(rand(5))
|
|||
-0.16210233
|
||||
0.12311903```
|
||||
"""
|
||||
struct Dense{F,S,T}
|
||||
struct Dense{F,S<:AbstractArray,T<:AbstractArray}
|
||||
W::S
|
||||
b::T
|
||||
σ::F
|
||||
|
|
|
@ -28,6 +28,14 @@ import Flux: activations
|
|||
end
|
||||
|
||||
@testset "Dense" begin
|
||||
@testset "constructors" begin
|
||||
@test size(Dense(10, 100).W) == (100, 10)
|
||||
@test Dense(rand(100,10), rand(10)).σ == identity
|
||||
|
||||
@test_throws MethodError Dense(10, 10.5)
|
||||
@test_throws MethodError Dense(10, 10.5, tanh)
|
||||
end
|
||||
|
||||
@test length(Dense(10, 5)(randn(10))) == 5
|
||||
@test_throws DimensionMismatch Dense(10, 5)(randn(1))
|
||||
@test_throws MethodError Dense(10, 5)(1) # avoid broadcasting
|
||||
|
@ -37,7 +45,6 @@ import Flux: activations
|
|||
@test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == 10*ones(1, 2)
|
||||
@test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(2, 1)
|
||||
@test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
|
||||
|
||||
end
|
||||
|
||||
@testset "Diagonal" begin
|
||||
|
|
Loading…
Reference in New Issue