Get the batchnorm working without cache

parent 4916c8e6da
commit 8f43258ab7
@@ -1,6 +1,6 @@
 using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
   cudnnBatchNormMode_t, cudnnHandle_t, libcudnn_handle, cudnnDataType, TensorDesc, FilterDesc
-import Flux.data
+import ..Flux: data

 mutable struct DropoutDesc
   ptr::Ptr{Void}
@@ -63,7 +63,7 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
   end
   xd = TensorDesc(x)
   yd = TensorDesc(y)
-  gd = TensorDesc(T, (1,1,length(g),1))
+  gd = TensorDesc(T, dims)

   if training

@@ -136,7 +136,7 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
   xd = TensorDesc(x)
   dyd = TensorDesc(dy)
   dxd = TensorDesc(dx)
-  gd = TensorDesc(T, (1,1,length(g),1))
+  gd = TensorDesc(T, _wsize(x))
   if cache !== nothing
     mean, ivar = cache.mean, cache.ivar
     info("mean and ivar are fetched from the cache")
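The `_wsize` helper that replaces the hard-coded `(1,1,length(g),1)` shape above, and the `_reddims` helper used in the next hunk, are not shown in this diff. A minimal sketch of what they plausibly compute for Flux's W×H×C×N layout (the names come from the diff; the bodies are an assumption):

# Assumed helper bodies, inferred from the (1,1,C,1) shapes they replace.
# _wsize: broadcast shape of a per-channel parameter — singleton in every
# dimension except the channel one (second-to-last in Flux's layout).
_wsize(y) = (map(_ -> 1, size(y)[1:end-2])..., size(y)[end-1], 1)

# _reddims: dimensions to sum over when reducing an array to one value
# per channel; for a 4-d array this is (1, 2, 4).
_reddims(y) = (collect(1:ndims(y)-2)..., ndims(y))

For a 4-d activation of size (W, H, C, N), `_wsize` gives `(1, 1, C, 1)`, matching the literal it replaces, and `_reddims` gives `(1, 2, 4)`, matching the `squeeze(..., (1,2,4))` calls below.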
@@ -167,9 +167,9 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
                 gd, g, dg, db,
                 eps, mean, ivar)
   else
-    ivar = 1 ./ sqrt.(reshape(running_var, (1, 1, length(running_var), 1)) .+ eps)
-    dx .= dy .* reshape(g, (1, 1, length(g), 1)) .* ivar
-    dg .= squeeze(sum(dy .* (x .- reshape(running_mean, (1, 1, length(running_mean), 1))) .* ivar, _reddims(dy)), (1,2,4))
+    ivar = 1 ./ sqrt.(reshape(running_var, _wsize(x)) .+ eps)
+    dx .= dy .* reshape(g, _wsize(x)) .* ivar
+    dg .= squeeze(sum(dy .* (x .- reshape(running_mean, _wsize(x))) .* ivar, _reddims(dy)), (1,2,4))
     db .= squeeze(sum(dy, _reddims(dy)), (1,2,4))
   end
 end
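The `cache.mean` / `cache.ivar` reads in the previous hunk imply a small mutable container that a forward pass could fill in; since this commit passes `cache = nothing` everywhere, that path stays dormant. A plausible shape for it (hypothetical, not part of this diff):

# Hypothetical container for saved batch statistics; unused while
# `cache = nothing` is passed, as it is throughout this commit.
mutable struct BNCache
  mean  # batch mean saved by the forward pass
  ivar  # inverse variance saved by the forward pass
end

BNCache() = BNCache(nothing, nothing)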
@@ -179,6 +179,13 @@ end
 import ..Flux: Flux
 import ..Tracker: track, back, @back, istracked, TrackedArray

+CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
+CuParam45{T} = Union{CuParam{T,4},CuParam{T,5}}
+CuBatchNorm{T} = Flux.BatchNorm{<:Union{typeof(identity),typeof(relu)},<:CuParam{T,1},<:CuParam{T,1},<:T}
+
+(BN::BatchNorm)(x::CuParam45{T}) =
+  batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ, BN.momentum; cache = nothing, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active)
+
 _batchnorm(g, b, x, running_mean, running_var, momentum,
            cache, alpha, beta, eps, training) =
   batchnorm(g, b, x, running_mean, running_var, momentum, cache = cache, alpha = alpha, beta = beta, eps = eps, training = training)
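What the new aliases buy: a 4-d or 5-d CuArray (plain or `TrackedArray`-wrapped) is meant to dispatch to the CUDNN path rather than the generic broadcast implementation. A hypothetical usage sketch, assuming a CuArrays-era GPU setup (nothing below is in the diff itself):

using Flux, CuArrays

# Move a BatchNorm layer and a W×H×C×N batch to the GPU.
m = BatchNorm(10) |> gpu
x = cu(rand(Float32, 28, 28, 10, 4))   # matches CuParam{Float32,4}

# Intended to hit the new (BN::BatchNorm)(x::CuParam45{T}) method,
# which forwards to cudnnBNForward! via batchnorm.
y = m(x)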
@@ -104,7 +104,6 @@ mutable struct BatchNorm{F,V,W,N}
   σ::W  # moving std
   ϵ::N
   momentum::N
-  cache
   active::Bool
 end

@@ -113,44 +112,39 @@ BatchNorm(chs::Integer, λ = identity;
   BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
-            zeros(chs), ones(chs), ϵ, momentum, nothing, true)
+            zeros(chs), ones(chs), ϵ, momentum, true)

-function batchnorm(γ, β, x, μ, σ, momentum; cache = nothing, alpha = 1, beta = 0, eps = 1.0e-5, training = true)
-  size(x, ndims(x)-1) == length(β) ||
+function (BN::BatchNorm)(x)
+  size(x, ndims(x)-1) == length(BN.β) ||
     error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
+  γ, β = BN.γ, BN.β
   dims = length(size(x))
   channels = size(x, dims-1)
   affine_shape = ones(Int, dims)
   affine_shape[end-1] = channels
   m = prod(size(x)[1:end-2]) * size(x)[end]

-  if !training
-    μ_curr = reshape(μ, affine_shape...)
-    σ_curr = reshape(σ, affine_shape...)
+  if !BN.active
+    μ = reshape(BN.μ, affine_shape...)
+    σ = reshape(BN.σ, affine_shape...)
   else
     T = eltype(x)

-    eps = Flux.data(convert(T, eps))
+    ϵ = data(convert(T, BN.ϵ))
     axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
-    μ_curr = mean(x, axes)
-    σ_curr = sqrt.(mean((x .- μ_curr).^2, axes) .+ eps)
+    μ = mean(x, axes)
+    σ = sqrt.(mean((x .- μ).^2, axes) .+ ϵ)

     # update moving mean/std
-    mtm = Flux.data(convert(T, momentum))
-    μ .= (1 - mtm) .* μ .+ mtm .* squeeze(Flux.data(μ_curr), (axes...))
-    σ .= (1 - mtm) .* σ .+ mtm .* squeeze(Flux.data(σ_curr), (axes...)) .* m ./ (m - 1)
+    mtm = data(convert(T, BN.momentum))
+    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* squeeze(data(μ), (axes...))
+    BN.σ = (1 - mtm) .* BN.σ .+ mtm .* squeeze(data(σ), (axes...)) .* m ./ (m - 1)
   end

-  reshape(γ, affine_shape...) .* ((x .- μ_curr) ./ σ_curr) .+ reshape(β, affine_shape...)
+  let λ = BN.λ
+    λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...))
+  end
 end

-(BN::BatchNorm)(x) = BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ, BN.momentum; cache = BN.cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active))
-
-Flux.treelike(BatchNorm)
+treelike(BatchNorm)

-# children(BN::BatchNorm) =
-#   (BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.ϵ, BN.momentum, BN.active)
-#
-# mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN)
-#   BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ), BN.ϵ, BN.momentum, BN.active)
-

 _testmode!(BN::BatchNorm, test) = (BN.active = !test)
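A quick sanity check of the restored CPU path (hypothetical usage against the Flux ~0.6-era API of this commit): in training mode the layer normalises with batch statistics and updates the moving `μ`/`σ`; after `testmode!` it reuses the stored statistics.

using Flux
using Flux: testmode!

m = BatchNorm(3)
x = rand(3, 4)    # channels × batch: the channel dim is second-to-last

y = m(x)          # active: normalise with batch stats, update m.μ and m.σ

testmode!(m)      # flips m.active to false via _testmode!
ŷ = m(x)          # inactive: reshape and use the moving mean/std instead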