This commit is contained in:
Avik Pal 2018-07-17 11:22:12 +05:30
parent 531ecccd38
commit 7dd5ec16c9
2 changed files with 4 additions and 5 deletions

View File

@ -187,7 +187,5 @@ batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} = running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
@grad function batchnorm(g, b, x, running_mean, running_var, momentum; kw...) @grad batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
y = batchnorm(data(g), data(b), data(x), running_mean, running_var, momentum; kw...) batchnorm(data.((g, b, x))..., running_mean, running_var, momentum; kw...), Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))..., nothing, nothing, nothing)
y, Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.(g, b, x, Δ), running_mean, running_var, momentum; kw...)), nothing, nothing, nothing)
end

View File

@ -110,10 +110,11 @@ mutable struct BatchNorm
active::Bool active::Bool
end end
# NOTE: Keeping the ϵ smaller than 1e-5 is not supported by CUDNN
function BatchNorm(chs::Integer, λ = identity; function BatchNorm(chs::Integer, λ = identity;
initβ = x->zeros(Float32,x), initβ = x->zeros(Float32,x),
initγ = x->ones(Float32,x), initγ = x->ones(Float32,x),
ϵ = 1f-8, ϵ = 1f-5,
momentum = 0.1f0) momentum = 0.1f0)
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)), BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
zeros(Float32, chs), ones(Float32, chs), ϵ, momentum, true) zeros(Float32, chs), ones(Float32, chs), ϵ, momentum, true)