Fix layers errors

This commit is contained in:
Avik Pal 2018-08-11 17:20:27 +05:30
parent 5186e3ba18
commit d3c78a80be
4 changed files with 16 additions and 10 deletions

View File

@ -130,13 +130,13 @@ function (BN::BatchNorm)(x)
ϵ = data(convert(T, BN.ϵ))
axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
μ = mean(x, axes)
σ = sqrt.(mean((x .- μ).^2, axes) .+ ϵ)
μ = mean(x, dims = axes)
σ = sqrt.(mean((x .- μ).^2, dims = axes) .+ ϵ)
# update moving mean/std
mtm = data(convert(T, BN.momentum))
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* squeeze(data(μ), (axes...,))
BN.σ = (1 - mtm) .* BN.σ .+ mtm .* squeeze(data(σ), (axes...,)) .* m ./ (m - 1)
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* dropdims(data(μ), dims = (axes...,))
BN.σ = (1 - mtm) .* BN.σ .+ mtm .* dropdims(data(σ), dims = (axes...,)) .* m ./ (m - 1)
end
let λ = BN.λ

View File

@ -53,17 +53,17 @@ end
# .1 * 4 + 0 = .4
@test m.μ reshape([0.3, 0.4], 2, 1)
# julia> .1 .* std(x, 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
# julia> .1 .* std(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
# 2×1 Array{Float64,2}:
# 1.14495
# 1.14495
@test m.σ .1 .* std(x.data, 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
@test m.σ .1 .* std(x.data, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
testmode!(m)
@test !m.active
x = m(x).data
@test x[1] (1 - 0.3) / 1.1449489742783179
@test x[1] (1 .- 0.3) / 1.1449489742783179
end
# with activation function

View File

@ -42,8 +42,8 @@ const ϵ = 1e-7
logŷ, y = randn(3), rand(3)
@testset "binarycrossentropy" begin
@test binarycrossentropy.(σ.(logŷ), y; ϵ=0) -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
@test binarycrossentropy.(σ.(logŷ), y) -y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 - y).*log.(1 - σ.(logŷ) .+ eps.(σ.(logŷ)))
@test binarycrossentropy.(σ.(logŷ), y; ϵ=0) -y.*log.(σ.(logŷ)) - (1 .- y).*log.(1 .- σ.(logŷ))
@test binarycrossentropy.(σ.(logŷ), y) -y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 .- y).*log.(1 .- σ.(logŷ) .+ eps.(σ.(logŷ)))
end
@testset "logitbinarycrossentropy" begin

View File

@ -5,11 +5,17 @@ Random.seed!(0)
@testset "Flux" begin
println("Testing")
include("utils.jl")
include("tracker.jl")
# println("Testing")
# include("tracker.jl")
println("Testing")
include("layers/normalisation.jl")
println("Testing")
include("layers/stateless.jl")
println("Testing")
include("optimise.jl")
println("Testing")
include("data.jl")
# if Base.find_in_path("CuArrays") ≠ nothing