Fix
This commit is contained in:
parent
531ecccd38
commit
7dd5ec16c9
@ -187,7 +187,5 @@ batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray
|
||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||
|
||||
@grad function batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
|
||||
y = batchnorm(data(g), data(b), data(x), running_mean, running_var, momentum; kw...)
|
||||
y, Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.(g, b, x, Δ), running_mean, running_var, momentum; kw...)), nothing, nothing, nothing)
|
||||
end
|
||||
@grad batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
|
||||
batchnorm(data.((g, b, x))..., running_mean, running_var, momentum; kw...), Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))..., nothing, nothing, nothing)
|
||||
|
@ -110,10 +110,11 @@ mutable struct BatchNorm
|
||||
active::Bool
|
||||
end
|
||||
|
||||
# NOTE: Keeping the ϵ smaller than 1e-5 is not supported by CUDNN
|
||||
function BatchNorm(chs::Integer, λ = identity;
|
||||
initβ = x->zeros(Float32,x),
|
||||
initγ = x->ones(Float32,x),
|
||||
ϵ = 1f-8,
|
||||
ϵ = 1f-5,
|
||||
momentum = 0.1f0)
|
||||
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
|
||||
zeros(Float32, chs), ones(Float32, chs), ϵ, momentum, true)
|
||||
|
Loading…
Reference in New Issue
Block a user