batchnorm: make CuArrays happy
This commit is contained in:
parent
477da75428
commit
88bd8a8fbd
@ -66,7 +66,7 @@ julia> m = Chain(
|
|||||||
softmax)
|
softmax)
|
||||||
Chain(Dense(784, 64), BatchNorm(64, λ = NNlib.relu), Dense(64, 10), BatchNorm(10), NNlib.softmax)
|
Chain(Dense(784, 64), BatchNorm(64, λ = NNlib.relu), Dense(64, 10), BatchNorm(10), NNlib.softmax)
|
||||||
|
|
||||||
julia> opt = SGD(params(m), 10) # a crazy learning rate
|
julia> opt = SGD(params(m), 10, decay = .1) # a crazy learning rate
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
mutable struct BatchNorm{F,V,N}
|
mutable struct BatchNorm{F,V,N}
|
||||||
@ -85,6 +85,8 @@ BatchNorm(dims::Integer...; λ = identity,
|
|||||||
BatchNorm(λ, param(initβ(dims)), param(initγ(dims)), 0., 1., ϵ, momentum, true)
|
BatchNorm(λ, param(initβ(dims)), param(initγ(dims)), 0., 1., ϵ, momentum, true)
|
||||||
|
|
||||||
function (BN::BatchNorm)(x)
|
function (BN::BatchNorm)(x)
|
||||||
|
λ, γ, β = BN.λ, BN.γ, BN.β
|
||||||
|
|
||||||
if !BN.active
|
if !BN.active
|
||||||
μ = BN.μ
|
μ = BN.μ
|
||||||
σ = BN.σ
|
σ = BN.σ
|
||||||
@ -102,7 +104,7 @@ function (BN::BatchNorm)(x)
|
|||||||
BN.σ = (1 - mtm) .* BN.σ .+ mtm .* σ.data .* m ./ (m - 1)
|
BN.σ = (1 - mtm) .* BN.σ .+ mtm .* σ.data .* m ./ (m - 1)
|
||||||
end
|
end
|
||||||
|
|
||||||
BN.λ.(BN.γ .* ((x .- μ) ./ σ) .+ BN.β)
|
λ.(γ .* ((x .- μ) ./ σ) .+ β)
|
||||||
end
|
end
|
||||||
|
|
||||||
children(BN::BatchNorm) =
|
children(BN::BatchNorm) =
|
||||||
|
Loading…
Reference in New Issue
Block a user