Merge pull request #570 from avik-pal/ap/batchnorm_fixes
Patches for default initializers
This commit is contained in:
commit
013b421b08
@ -83,12 +83,12 @@ DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
|||||||
stride = 1, pad = 0) where {T,N} =
|
stride = 1, pad = 0) where {T,N} =
|
||||||
DepthwiseConv(σ, w, b, expand.(sub2(Val(N)), (stride, pad))...)
|
DepthwiseConv(σ, w, b, expand.(sub2(Val(N)), (stride, pad))...)
|
||||||
|
|
||||||
DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = initn,
|
DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = glorot_uniform,
|
||||||
stride = 1, pad = 0) where N =
|
stride = 1, pad = 0) where N =
|
||||||
DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ,
|
DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ,
|
||||||
stride = stride, pad = pad)
|
stride = stride, pad = pad)
|
||||||
|
|
||||||
DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
|
DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = glorot_uniform,
|
||||||
stride::NTuple{N,Integer} = map(_->1,k),
|
stride::NTuple{N,Integer} = map(_->1,k),
|
||||||
pad::NTuple{N,Integer} = map(_->0,k)) where N =
|
pad::NTuple{N,Integer} = map(_->0,k)) where N =
|
||||||
DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ,
|
DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ,
|
||||||
|
@ -106,7 +106,7 @@ mutable struct BatchNorm{F,V,W,N}
|
|||||||
end
|
end
|
||||||
|
|
||||||
BatchNorm(chs::Integer, λ = identity;
|
BatchNorm(chs::Integer, λ = identity;
|
||||||
initβ = (i) -> zeros(i), initγ = (i) -> ones(i), ϵ = 1e-5, momentum = .1) =
|
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
|
||||||
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
|
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
|
||||||
zeros(chs), ones(chs), ϵ, momentum, true)
|
zeros(chs), ones(chs), ϵ, momentum, true)
|
||||||
|
|
||||||
|
@ -21,3 +21,15 @@ end
|
|||||||
|
|
||||||
@test size(m(r)) == (10, 5)
|
@test size(m(r)) == (10, 5)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@testset "Depthwise Conv" begin
|
||||||
|
r = zeros(Float32, 28, 28, 3, 5)
|
||||||
|
|
||||||
|
m1 = DepthwiseConv((2, 2), 3=>5)
|
||||||
|
|
||||||
|
@test size(m1(r), 3) == 15
|
||||||
|
|
||||||
|
m2 = DepthwiseConv((2, 2), 3)
|
||||||
|
|
||||||
|
@test size(m2(r), 3) == 3
|
||||||
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user