Forward Pass for BatchNorm Added
This commit is contained in:
parent
8f7ee76752
commit
d4b066fdf9
@ -1,5 +1,7 @@
|
|||||||
using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
|
using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
|
||||||
cudnnDataType, TensorDesc, FilterDesc
|
cudnnBatchNormMode_t, cudnnHandle_t, libcudnn_handle, cudnnDataType, TensorDesc, FilterDesc
|
||||||
|
using CuArrays
|
||||||
|
using Flux
|
||||||
|
|
||||||
mutable struct DropoutDesc
|
mutable struct DropoutDesc
|
||||||
ptr::Ptr{Void}
|
ptr::Ptr{Void}
|
||||||
@ -22,6 +24,92 @@ function DropoutDesc(ρ::Real; seed::Integer=0)
|
|||||||
return desc
|
return desc
|
||||||
end
|
end
|
||||||
|
|
||||||
|
CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
|
||||||
|
CuBatchNorm{T} = Flux.BatchNorm{<:Union{typeof(identity),typeof(relu)},
|
||||||
|
<:CuParam{T,1},<:CuArray{T,1},
|
||||||
|
<:Union{Float32,Float64}}
|
||||||
|
|
||||||
|
CuBatchNorm(chs::Integer, λ = identity;
|
||||||
|
initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1) =
|
||||||
|
BatchNorm(λ, param(cu(initβ(Float32,chs))), param(cu(initγ(Float32,chs))),
|
||||||
|
zeros(Float32,chs), ones(Float32,chs), ϵ, momentum, true)
|
||||||
|
|
||||||
|
const BATCHNORM_SPATIAL = 1
|
||||||
|
const BATCHNORM_ACTIVATION = 0
|
||||||
|
const BATCHNORM_MIN_EPS = 1e-5
|
||||||
|
|
||||||
|
@inline _wsize(y) = ((1 for _=1:ndims(y)-2)..., size(y)[end-1], 1)
|
||||||
|
|
||||||
|
mutable struct bncache
|
||||||
|
mean
|
||||||
|
ivar
|
||||||
|
end
|
||||||
|
|
||||||
|
bncache() = bncache(nothing, nothing)
|
||||||
|
|
||||||
|
(CuBN::CuBatchNorm)(x::CuArray{T}) where T<:Union{Float32, Float64} =
|
||||||
|
CuBN.λ.(cudnnBatchNormalizationForward(CuBN.γ, CuBN.β, x, CuBN.μ, CuBN.σ, CuBN.momentum, eps = CuBN.ϵ, training = CuBN.active))
|
||||||
|
|
||||||
|
function cudnnBatchNormalizationForward(g::CuArray{T}, b::CuArray{T}, x::CuArray{T},
|
||||||
|
running_mean::CuArray{T}, running_var::CuArray{T},
|
||||||
|
momentum::T; cache = nothing,
|
||||||
|
alpha = T(1), beta = T(0),
|
||||||
|
eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
|
||||||
|
y = similar(x)
|
||||||
|
dims = _wsize(x)
|
||||||
|
|
||||||
|
if(eps < BATCHNORM_MIN_EPS)
|
||||||
|
warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", BATCHNORM_MIN_EPS)
|
||||||
|
eps = BATCHNORM_MIN_EPS
|
||||||
|
end
|
||||||
|
|
||||||
|
if(training)
|
||||||
|
|
||||||
|
if(cache !== nothing)
|
||||||
|
mean = cu(zeros(T, dims...))
|
||||||
|
ivar = cu(ones(T, dims...))
|
||||||
|
else
|
||||||
|
mean = C_NULL
|
||||||
|
ivar = C_NULL
|
||||||
|
end
|
||||||
|
|
||||||
|
@check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t,
|
||||||
|
(cudnnHandle_t,cudnnBatchNormMode_t,Ptr{Void}, Ptr{Void},
|
||||||
|
Ptr{Void},Ptr{Void},Ptr{Void},Ptr{Void},
|
||||||
|
Ptr{Void},Ptr{Void},Ptr{Void},
|
||||||
|
Cdouble,Ptr{Void},Ptr{Void},
|
||||||
|
Cdouble,Ptr{Void},Ptr{Void}),
|
||||||
|
libcudnn_handle[], BATCHNORM_SPATIAL,
|
||||||
|
Ref(T(alpha)), Ref(T(beta)),
|
||||||
|
TensorDesc(x), x,
|
||||||
|
TensorDesc(y), y,
|
||||||
|
TensorDesc(g), g, b,
|
||||||
|
momentum, running_mean, running_var,
|
||||||
|
eps, mean, ivar)
|
||||||
|
|
||||||
|
if(cache !== nothing)
|
||||||
|
cache.mean = mean
|
||||||
|
cache.invvar = ivar
|
||||||
|
end
|
||||||
|
else
|
||||||
|
|
||||||
|
@check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t,
|
||||||
|
(cudnnHandle_t,cudnnBatchNormMode_t,Ptr{Void}, Ptr{Void},
|
||||||
|
Ptr{Void},Ptr{Void},Ptr{Void},Ptr{Void},
|
||||||
|
Ptr{Void},Ptr{Void},Ptr{Void},
|
||||||
|
Ptr{Void},Ptr{Void},
|
||||||
|
Cdouble),
|
||||||
|
libcudnn_handle[], BATCHNORM_SPATIAL,
|
||||||
|
Ref(T(alpha)), Ref(T(beta)),
|
||||||
|
TensorDesc(x), x,
|
||||||
|
TensorDesc(y), y,
|
||||||
|
TensorDesc(g), g, b,
|
||||||
|
running_mean, running_var,
|
||||||
|
eps)
|
||||||
|
end
|
||||||
|
y
|
||||||
|
end
|
||||||
|
|
||||||
const RNN_RELU = 0 # Stock RNN with ReLu activation
|
const RNN_RELU = 0 # Stock RNN with ReLu activation
|
||||||
const RNN_TANH = 1 # Stock RNN with tanh activation
|
const RNN_TANH = 1 # Stock RNN with tanh activation
|
||||||
const LSTM = 2 # LSTM with no peephole connections
|
const LSTM = 2 # LSTM with no peephole connections
|
||||||
|
Loading…
Reference in New Issue
Block a user