This commit is contained in:
Johnny Chen 2018-08-23 21:34:11 +08:00
parent dfe7578216
commit 6743d52d08
3 changed files with 34 additions and 1 deletions

View File

@ -75,10 +75,11 @@ end
@treelike Dense
function (a::Dense)(x)
function (a::Dense)(x::AbstractArray)
W, b, σ = a.W, a.b, a.σ
σ.(W*x .+ b)
end
(a::Dense)(x::Number) = a([x]) # prevent broadcasting of scalar
function Base.show(io::IO, l::Dense)
print(io, "Dense(", size(l.W, 2), ", ", size(l.W, 1))

31
test/layers/basic.jl Normal file
View File

@ -0,0 +1,31 @@
using Test, Random
@testset "basic" begin
@testset "Chain" begin
@test_nowarn Chain(Dense(10, 5, σ),Dense(5, 2), softmax)
@test_nowarn Chain(Dense(10, 5, σ),Dense(5, 2), softmax)(randn(10))
end
@testset "Dense" begin
@test length(Dense(10, 5)(randn(10))) == 5
@test_throws DimensionMismatch Dense(10, 5)(randn(1))
Random.seed!(0)
@test all(Dense(10, 1)(randn(10)).data .≈ 1.1774348382231168)
Random.seed!(0)
@test all(Dense(10, 2)(randn(10)).data .≈ [ -0.3624741476779616
-0.46724765394534323])
@test_throws DimensionMismatch Dense(10, 5)(1)
end
@testset "Diagonal" begin
@test length(Flux.Diagonal(10)(randn(10))) == 10
@test length(Flux.Diagonal(10)(1)) == 10
@test length(Flux.Diagonal(10)(randn(1))) == 10
@test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))
Random.seed!(0)
@test all(Flux.Diagonal(2)(randn(2)).data .≈ [ 0.6791074260357777,
0.8284134829000359])
end
end

View File

@ -25,6 +25,7 @@ insert!(LOAD_PATH, 2, "@v#.#")
include("utils.jl")
include("tracker.jl")
include("layers/basic.jl")
include("layers/normalisation.jl")
include("layers/stateless.jl")
include("optimise.jl")