Merge pull request #570 from avik-pal/ap/batchnorm_fixes
Patches for default initializers
This commit is contained in:
commit
013b421b08
|
@ -83,12 +83,12 @@ DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
|||
stride = 1, pad = 0) where {T,N} =
|
||||
DepthwiseConv(σ, w, b, expand.(sub2(Val(N)), (stride, pad))...)
|
||||
|
||||
DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = initn,
|
||||
DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = glorot_uniform,
|
||||
stride = 1, pad = 0) where N =
|
||||
DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ,
|
||||
stride = stride, pad = pad)
|
||||
|
||||
DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
|
||||
DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = glorot_uniform,
|
||||
stride::NTuple{N,Integer} = map(_->1,k),
|
||||
pad::NTuple{N,Integer} = map(_->0,k)) where N =
|
||||
DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ,
|
||||
|
|
|
@ -106,7 +106,7 @@ mutable struct BatchNorm{F,V,W,N}
|
|||
end
|
||||
|
||||
BatchNorm(chs::Integer, λ = identity;
|
||||
initβ = (i) -> zeros(i), initγ = (i) -> ones(i), ϵ = 1e-5, momentum = .1) =
|
||||
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
|
||||
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
|
||||
zeros(chs), ones(chs), ϵ, momentum, true)
|
||||
|
||||
|
|
|
@ -21,3 +21,15 @@ end
|
|||
|
||||
@test size(m(r)) == (10, 5)
|
||||
end
|
||||
|
||||
@testset "Depthwise Conv" begin
|
||||
r = zeros(Float32, 28, 28, 3, 5)
|
||||
|
||||
m1 = DepthwiseConv((2, 2), 3=>5)
|
||||
|
||||
@test size(m1(r), 3) == 15
|
||||
|
||||
m2 = DepthwiseConv((2, 2), 3)
|
||||
|
||||
@test size(m2(r), 3) == 3
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue