batchnorm: make CuArrays happy

This commit is contained in:
Iblis Lin 2017-11-02 13:40:06 +08:00
parent 477da75428
commit 88bd8a8fbd

View File

@ -66,7 +66,7 @@ julia> m = Chain(
softmax)
Chain(Dense(784, 64), BatchNorm(64, λ = NNlib.relu), Dense(64, 10), BatchNorm(10), NNlib.softmax)
julia> opt = SGD(params(m), 10) # a crazy learning rate
julia> opt = SGD(params(m), 10, decay = .1) # a crazy learning rate
```
"""
mutable struct BatchNorm{F,V,N}
@ -85,6 +85,8 @@ BatchNorm(dims::Integer...; λ = identity,
BatchNorm(λ, param(initβ(dims)), param(initγ(dims)), 0., 1., ϵ, momentum, true)
function (BN::BatchNorm)(x)
λ, γ, β = BN.λ, BN.γ, BN.β
if !BN.active
μ = BN.μ
σ = BN.σ
@ -102,7 +104,7 @@ function (BN::BatchNorm)(x)
BN.σ = (1 - mtm) .* BN.σ .+ mtm .* σ.data .* m ./ (m - 1)
end
BN.λ.(BN.γ .* ((x .- μ) ./ σ) .+ BN.β)
λ.(γ .* ((x .- μ) ./ σ) .+ β)
end
children(BN::BatchNorm) =