diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index 6f1d8b9e..abcd6737 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -187,7 +187,5 @@ batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} = track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) -@grad function batchnorm(g, b, x, running_mean, running_var, momentum; kw...) - y = batchnorm(data(g), data(b), data(x), running_mean, running_var, momentum; kw...) - y, Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.(g, b, x, Δ), running_mean, running_var, momentum; kw...)), nothing, nothing, nothing) -end +@grad batchnorm(g, b, x, running_mean, running_var, momentum; kw...) = + batchnorm(data.((g, b, x))..., running_mean, running_var, momentum; kw...), Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))..., nothing, nothing, nothing) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 40edaec6..44754815 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -110,10 +110,11 @@ mutable struct BatchNorm active::Bool end +# NOTE: Keeping the ϵ smaller than 1e-5 is not supported by CUDNN function BatchNorm(chs::Integer, λ = identity; initβ = x->zeros(Float32,x), initγ = x->ones(Float32,x), - ϵ = 1f-8, + ϵ = 1f-5, momentum = 0.1f0) BatchNorm(λ, param(initβ(chs)), param(initγ(chs)), zeros(Float32, chs), ones(Float32, chs), ϵ, momentum, true)