Fix issue #354
This commit is contained in:
parent
dfe7578216
commit
6743d52d08
|
@ -75,10 +75,11 @@ end
|
|||
|
||||
@treelike Dense
|
||||
|
||||
function (a::Dense)(x)
|
||||
function (a::Dense)(x::AbstractArray)
|
||||
W, b, σ = a.W, a.b, a.σ
|
||||
σ.(W*x .+ b)
|
||||
end
|
||||
(a::Dense)(x::Number) = a([x]) # prevent broadcasting of scalar
|
||||
|
||||
function Base.show(io::IO, l::Dense)
|
||||
print(io, "Dense(", size(l.W, 2), ", ", size(l.W, 1))
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
using Test, Random
|
||||
|
||||
|
||||
@testset "basic" begin
|
||||
@testset "Chain" begin
|
||||
@test_nowarn Chain(Dense(10, 5, σ),Dense(5, 2), softmax)
|
||||
@test_nowarn Chain(Dense(10, 5, σ),Dense(5, 2), softmax)(randn(10))
|
||||
end
|
||||
|
||||
@testset "Dense" begin
|
||||
@test length(Dense(10, 5)(randn(10))) == 5
|
||||
@test_throws DimensionMismatch Dense(10, 5)(randn(1))
|
||||
Random.seed!(0)
|
||||
@test all(Dense(10, 1)(randn(10)).data .≈ 1.1774348382231168)
|
||||
Random.seed!(0)
|
||||
@test all(Dense(10, 2)(randn(10)).data .≈ [ -0.3624741476779616
|
||||
-0.46724765394534323])
|
||||
|
||||
@test_throws DimensionMismatch Dense(10, 5)(1)
|
||||
end
|
||||
|
||||
@testset "Diagonal" begin
|
||||
@test length(Flux.Diagonal(10)(randn(10))) == 10
|
||||
@test length(Flux.Diagonal(10)(1)) == 10
|
||||
@test length(Flux.Diagonal(10)(randn(1))) == 10
|
||||
@test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))
|
||||
Random.seed!(0)
|
||||
@test all(Flux.Diagonal(2)(randn(2)).data .≈ [ 0.6791074260357777,
|
||||
0.8284134829000359])
|
||||
end
|
||||
end
|
|
@ -25,6 +25,7 @@ insert!(LOAD_PATH, 2, "@v#.#")
|
|||
|
||||
include("utils.jl")
|
||||
include("tracker.jl")
|
||||
include("layers/basic.jl")
|
||||
include("layers/normalisation.jl")
|
||||
include("layers/stateless.jl")
|
||||
include("optimise.jl")
|
||||
|
|
Loading…
Reference in New Issue