Fix issue #354
This commit is contained in:
parent
dfe7578216
commit
6743d52d08
@ -75,10 +75,11 @@ end
|
|||||||
|
|
||||||
@treelike Dense
|
@treelike Dense
|
||||||
|
|
||||||
function (a::Dense)(x)
|
function (a::Dense)(x::AbstractArray)
|
||||||
W, b, σ = a.W, a.b, a.σ
|
W, b, σ = a.W, a.b, a.σ
|
||||||
σ.(W*x .+ b)
|
σ.(W*x .+ b)
|
||||||
end
|
end
|
||||||
|
(a::Dense)(x::Number) = a([x]) # prevent broadcasting of scalar
|
||||||
|
|
||||||
function Base.show(io::IO, l::Dense)
|
function Base.show(io::IO, l::Dense)
|
||||||
print(io, "Dense(", size(l.W, 2), ", ", size(l.W, 1))
|
print(io, "Dense(", size(l.W, 2), ", ", size(l.W, 1))
|
||||||
|
31
test/layers/basic.jl
Normal file
31
test/layers/basic.jl
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
using Test, Random
|
||||||
|
|
||||||
|
|
||||||
|
@testset "basic" begin
|
||||||
|
@testset "Chain" begin
|
||||||
|
@test_nowarn Chain(Dense(10, 5, σ),Dense(5, 2), softmax)
|
||||||
|
@test_nowarn Chain(Dense(10, 5, σ),Dense(5, 2), softmax)(randn(10))
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset "Dense" begin
|
||||||
|
@test length(Dense(10, 5)(randn(10))) == 5
|
||||||
|
@test_throws DimensionMismatch Dense(10, 5)(randn(1))
|
||||||
|
Random.seed!(0)
|
||||||
|
@test all(Dense(10, 1)(randn(10)).data .≈ 1.1774348382231168)
|
||||||
|
Random.seed!(0)
|
||||||
|
@test all(Dense(10, 2)(randn(10)).data .≈ [ -0.3624741476779616
|
||||||
|
-0.46724765394534323])
|
||||||
|
|
||||||
|
@test_throws DimensionMismatch Dense(10, 5)(1)
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset "Diagonal" begin
|
||||||
|
@test length(Flux.Diagonal(10)(randn(10))) == 10
|
||||||
|
@test length(Flux.Diagonal(10)(1)) == 10
|
||||||
|
@test length(Flux.Diagonal(10)(randn(1))) == 10
|
||||||
|
@test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))
|
||||||
|
Random.seed!(0)
|
||||||
|
@test all(Flux.Diagonal(2)(randn(2)).data .≈ [ 0.6791074260357777,
|
||||||
|
0.8284134829000359])
|
||||||
|
end
|
||||||
|
end
|
@ -25,6 +25,7 @@ insert!(LOAD_PATH, 2, "@v#.#")
|
|||||||
|
|
||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
include("tracker.jl")
|
include("tracker.jl")
|
||||||
|
include("layers/basic.jl")
|
||||||
include("layers/normalisation.jl")
|
include("layers/normalisation.jl")
|
||||||
include("layers/stateless.jl")
|
include("layers/stateless.jl")
|
||||||
include("optimise.jl")
|
include("optimise.jl")
|
||||||
|
Loading…
Reference in New Issue
Block a user