batchnorm: make CuArrays happy
This commit is contained in:
parent
477da75428
commit
88bd8a8fbd
@ -66,7 +66,7 @@ julia> m = Chain(
|
||||
softmax)
|
||||
Chain(Dense(784, 64), BatchNorm(64, λ = NNlib.relu), Dense(64, 10), BatchNorm(10), NNlib.softmax)
|
||||
|
||||
julia> opt = SGD(params(m), 10) # a crazy learning rate
|
||||
julia> opt = SGD(params(m), 10, decay = .1) # a crazy learning rate
|
||||
```
|
||||
"""
|
||||
mutable struct BatchNorm{F,V,N}
|
||||
@ -85,6 +85,8 @@ BatchNorm(dims::Integer...; λ = identity,
|
||||
BatchNorm(λ, param(initβ(dims)), param(initγ(dims)), 0., 1., ϵ, momentum, true)
|
||||
|
||||
function (BN::BatchNorm)(x)
|
||||
λ, γ, β = BN.λ, BN.γ, BN.β
|
||||
|
||||
if !BN.active
|
||||
μ = BN.μ
|
||||
σ = BN.σ
|
||||
@ -102,7 +104,7 @@ function (BN::BatchNorm)(x)
|
||||
BN.σ = (1 - mtm) .* BN.σ .+ mtm .* σ.data .* m ./ (m - 1)
|
||||
end
|
||||
|
||||
BN.λ.(BN.γ .* ((x .- μ) ./ σ) .+ BN.β)
|
||||
λ.(γ .* ((x .- μ) ./ σ) .+ β)
|
||||
end
|
||||
|
||||
children(BN::BatchNorm) =
|
||||
|
Loading…
Reference in New Issue
Block a user