Adding untested backward pass code
This commit is contained in:
parent
a83e5d696d
commit
f12e367cab
@ -47,15 +47,25 @@ end
|
|||||||
|
|
||||||
bncache() = bncache(nothing, nothing)
|
bncache() = bncache(nothing, nothing)
|
||||||
|
|
||||||
(CuBN::CuBatchNorm)(x::CuArray{T}) where T<:Union{Float32, Float64} =
|
(CuBN::CuBatchNorm)(x::CuArray{T}; cache = nothing) where T<:Union{Float32, Float64} =
|
||||||
CuBN.λ.(cudnnBatchNormalizationForward(CuBN.γ, CuBN.β, x, CuBN.μ, CuBN.σ, CuBN.momentum, eps = CuBN.ϵ, training = CuBN.active))
|
CuBN.λ.(cudnnBNForward(CuBN.γ, CuBN.β, x, CuBN.μ, CuBN.σ, CuBN.momentum, cache = cache, eps = CuBN.ϵ, training = CuBN.active))
|
||||||
|
|
||||||
function cudnnBatchNormalizationForward(g::CuArray{T}, b::CuArray{T}, x::CuArray{T},
|
function cudnnBNForward(g::CuArray{T}, b::CuArray{T}, x::CuArray{T},
|
||||||
running_mean::CuArray{T}, running_var::CuArray{T},
|
running_mean::CuArray{T}, running_var::CuArray{T},
|
||||||
momentum::T; cache = nothing,
|
momentum::T; cache = nothing,
|
||||||
alpha = T(1), beta = T(0),
|
alpha = T(1), beta = T(0),
|
||||||
eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
|
eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
|
||||||
y = similar(x)
|
y = similar(x)
|
||||||
|
cudnnBNForward!(y, g, b, x, running_mean, running_var, momentum, cache = cache
|
||||||
|
alpha = alpha, beta = beta, eps = eps, training = training)
|
||||||
|
y
|
||||||
|
end
|
||||||
|
|
||||||
|
function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray{T},
|
||||||
|
running_mean::CuArray{T}, running_var::CuArray{T},
|
||||||
|
momentum::T; cache = nothing,
|
||||||
|
alpha = T(1), beta = T(0),
|
||||||
|
eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
|
||||||
dims = _wsize(x)
|
dims = _wsize(x)
|
||||||
|
|
||||||
if(eps < BATCHNORM_MIN_EPS)
|
if(eps < BATCHNORM_MIN_EPS)
|
||||||
@ -74,11 +84,13 @@ function cudnnBatchNormalizationForward(g::CuArray{T}, b::CuArray{T}, x::CuArray
|
|||||||
end
|
end
|
||||||
|
|
||||||
@check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t,
|
@check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t,
|
||||||
(cudnnHandle_t,cudnnBatchNormMode_t,Ptr{Void}, Ptr{Void},
|
(cudnnHandle_t,cudnnBatchNormMode_t,
|
||||||
Ptr{Void},Ptr{Void},Ptr{Void},Ptr{Void},
|
Ptr{T}, Ptr{T},
|
||||||
Ptr{Void},Ptr{Void},Ptr{Void},
|
Ptr{Void}, Ptr{T},
|
||||||
Cdouble,Ptr{Void},Ptr{Void},
|
Ptr{Void}, Ptr{T},
|
||||||
Cdouble,Ptr{Void},Ptr{Void}),
|
Ptr{Void}, Ptr{T}, Ptr{T},
|
||||||
|
Cdouble, Ptr{T}, Ptr{T},
|
||||||
|
Cdouble, Ptr{T}, Ptr{T}),
|
||||||
libcudnn_handle[], BATCHNORM_SPATIAL,
|
libcudnn_handle[], BATCHNORM_SPATIAL,
|
||||||
Ref(T(alpha)), Ref(T(beta)),
|
Ref(T(alpha)), Ref(T(beta)),
|
||||||
TensorDesc(x), x,
|
TensorDesc(x), x,
|
||||||
@ -94,10 +106,12 @@ function cudnnBatchNormalizationForward(g::CuArray{T}, b::CuArray{T}, x::CuArray
|
|||||||
else
|
else
|
||||||
|
|
||||||
@check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t,
|
@check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t,
|
||||||
(cudnnHandle_t,cudnnBatchNormMode_t,Ptr{Void}, Ptr{Void},
|
(cudnnHandle_t,cudnnBatchNormMode_t,
|
||||||
Ptr{Void},Ptr{Void},Ptr{Void},Ptr{Void},
|
Ptr{T}, Ptr{T},
|
||||||
Ptr{Void},Ptr{Void},Ptr{Void},
|
Ptr{Void}, Ptr{T},
|
||||||
Ptr{Void},Ptr{Void},
|
Ptr{Void}, Ptr{T},
|
||||||
|
Ptr{Void}, Ptr{T}, Ptr{T},
|
||||||
|
Ptr{T}, Ptr{T},
|
||||||
Cdouble),
|
Cdouble),
|
||||||
libcudnn_handle[], BATCHNORM_SPATIAL,
|
libcudnn_handle[], BATCHNORM_SPATIAL,
|
||||||
Ref(T(alpha)), Ref(T(beta)),
|
Ref(T(alpha)), Ref(T(beta)),
|
||||||
@ -107,7 +121,47 @@ function cudnnBatchNormalizationForward(g::CuArray{T}, b::CuArray{T}, x::CuArray
|
|||||||
running_mean, running_var,
|
running_mean, running_var,
|
||||||
eps)
|
eps)
|
||||||
end
|
end
|
||||||
y
|
end
|
||||||
|
|
||||||
|
function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
|
||||||
|
dx::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
|
||||||
|
running_mean::CuArray{T}, running_var::CuArray{T}
|
||||||
|
momentum; training = true,
|
||||||
|
cache = nothing, eps = T(1e-5),
|
||||||
|
alpha = T(1), beta = T(0),
|
||||||
|
dalpha = T(1), dbeta = T(0)) where T<:Union{Float32, Float64}
|
||||||
|
if(training)
|
||||||
|
|
||||||
|
if cache !== nothing
|
||||||
|
mean, ivar = cache.mean, cache.ivar
|
||||||
|
cache_verbose && info("mean and ivar are fetched from the cache")
|
||||||
|
else
|
||||||
|
mean, ivar = C_NULL, C_NULL
|
||||||
|
end
|
||||||
|
|
||||||
|
@check ccall((:cudnnBatchNormalizationBackward, libcudnn), cudnnStatus_t,
|
||||||
|
(cudnnHandle_t,cudnnBatchNormMode_t,
|
||||||
|
Ptr{T}, Ptr{T},
|
||||||
|
Ptr{T}, Ptr{T},
|
||||||
|
Ptr{Void}, Ptr{T},
|
||||||
|
Ptr{Void}, Ptr{T},
|
||||||
|
Ptr{Void}, Ptr{T},
|
||||||
|
Ptr{Void}, Ptr{T}, Ptr{T}, Ptr{T},
|
||||||
|
Cdouble, Ptr{T}, Ptr{T}),
|
||||||
|
libcudnn_handle[], BATCHNORM_SPATIAL,
|
||||||
|
Ref(T(alpha)), Ref(T(beta)),
|
||||||
|
Ref(T(dalpha)), Ref(T(dbeta)),
|
||||||
|
TensorDesc(x), x,
|
||||||
|
TensorDesc(dy), dy,
|
||||||
|
TensorDesc(dx), dx,
|
||||||
|
TensorDesc(g), g, dg, db,
|
||||||
|
eps, mean, ivar)
|
||||||
|
else
|
||||||
|
ivar = 1 ./ sqrt.(running_var .+ eps)
|
||||||
|
dx = dy .* g .* ivar
|
||||||
|
dg = sum(dy .* (x .- running_mean) .* ivar, _reddims(dy))
|
||||||
|
db = sum(dy, _reddims(dy))
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
const RNN_RELU = 0 # Stock RNN with ReLu activation
|
const RNN_RELU = 0 # Stock RNN with ReLu activation
|
||||||
|
Loading…
Reference in New Issue
Block a user