Adding untested backward pass code

Avik Pal 2018-06-12 18:26:09 +05:30
parent a83e5d696d
commit f12e367cab


@@ -47,15 +47,25 @@ end
 
 bncache() = bncache(nothing, nothing)
 
-(CuBN::CuBatchNorm)(x::CuArray{T}) where T<:Union{Float32, Float64} =
-  CuBN.λ.(cudnnBatchNormalizationForward(CuBN.γ, CuBN.β, x, CuBN.μ, CuBN.σ, CuBN.momentum, eps = CuBN.ϵ, training = CuBN.active))
+(CuBN::CuBatchNorm)(x::CuArray{T}; cache = nothing) where T<:Union{Float32, Float64} =
+  CuBN.λ.(cudnnBNForward(CuBN.γ, CuBN.β, x, CuBN.μ, CuBN.σ, CuBN.momentum, cache = cache, eps = CuBN.ϵ, training = CuBN.active))
 
-function cudnnBatchNormalizationForward(g::CuArray{T}, b::CuArray{T}, x::CuArray{T},
-                                        running_mean::CuArray{T}, running_var::CuArray{T},
-                                        momentum::T; cache = nothing,
-                                        alpha = T(1), beta = T(0),
-                                        eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
-  y = similar(x)
+# Allocating wrapper: creates the output buffer, then defers to the
+# mutating kernel below.
+function cudnnBNForward(g::CuArray{T}, b::CuArray{T}, x::CuArray{T},
+                        running_mean::CuArray{T}, running_var::CuArray{T},
+                        momentum::T; cache = nothing,
+                        alpha = T(1), beta = T(0),
+                        eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
+  y = similar(x)
+  cudnnBNForward!(y, g, b, x, running_mean, running_var, momentum, cache = cache,
+                  alpha = alpha, beta = beta, eps = eps, training = training)
+  y
+end
+
+# In-place variant: writes the normalized result into `y`.
+function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray{T},
+                         running_mean::CuArray{T}, running_var::CuArray{T},
+                         momentum::T; cache = nothing,
+                         alpha = T(1), beta = T(0),
+                         eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
   dims = _wsize(x)
   if(eps < BATCHNORM_MIN_EPS)
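
The refactor above follows the standard Julia allocate/mutate split: cudnnBNForward allocates the output and forwards every keyword to cudnnBNForward!, which fills it. A rough usage sketch, assuming the CuArrays environment of the surrounding file; only the two function names and bncache come from this diff, while γ, β, μ, v and the shapes are illustrative stand-ins for the layer's per-channel arrays:

γ, β = cu(ones(Float32, 3)), cu(zeros(Float32, 3))  # per-channel scale and shift
μ, v = cu(zeros(Float32, 3)), cu(ones(Float32, 3))  # running mean and variance
x = cu(rand(Float32, 28, 28, 3, 16))                # W×H×C×N batch (assumed layout)
c = bncache()                                       # saves mean/ivar for the backward pass
y = cudnnBNForward(γ, β, x, μ, v, 0.1f0, cache = c, training = true)   # allocating form
cudnnBNForward!(y, γ, β, x, μ, v, 0.1f0, cache = c, training = true)   # reuses y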
@@ -74,11 +84,13 @@ function cudnnBatchNormalizationForward(g::CuArray{T}, b::CuArray{T}, x::CuArray
   end
 
   @check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t,
-               (cudnnHandle_t,cudnnBatchNormMode_t,Ptr{Void}, Ptr{Void},
-                Ptr{Void},Ptr{Void},Ptr{Void},Ptr{Void},
-                Ptr{Void},Ptr{Void},Ptr{Void},
-                Cdouble,Ptr{Void},Ptr{Void},
-                Cdouble,Ptr{Void},Ptr{Void}),
+               (cudnnHandle_t,cudnnBatchNormMode_t,
+                Ptr{T}, Ptr{T},
+                Ptr{Void}, Ptr{T},
+                Ptr{Void}, Ptr{T},
+                Ptr{Void}, Ptr{T}, Ptr{T},
+                Cdouble, Ptr{T}, Ptr{T},
+                Cdouble, Ptr{T}, Ptr{T}),
                libcudnn_handle[], BATCHNORM_SPATIAL,
                Ref(T(alpha)), Ref(T(beta)),
                TensorDesc(x), x,
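
The retyped tuple mirrors how the arguments are passed on the Julia side: scalar factors such as alpha and beta travel as Ref(T(alpha)), which ccall converts to a Ptr{T}, while the opaque tensor descriptors stay Ptr{Void}. A minimal, self-contained illustration of that Ref-to-Ptr calling convention, using libc's frexp rather than cuDNN:

r = Ref{Cint}(0)                                          # int slot the callee writes through
m = ccall(:frexp, Cdouble, (Cdouble, Ptr{Cint}), 8.0, r)  # double frexp(double, int*)
(m, r[])                                                  # (0.5, 4), since 8.0 == 0.5 * 2^4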
@@ -94,10 +106,12 @@ function cudnnBatchNormalizationForward(g::CuArray{T}, b::CuArray{T}, x::CuArray
   else
     @check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t,
-                 (cudnnHandle_t,cudnnBatchNormMode_t,Ptr{Void}, Ptr{Void},
-                  Ptr{Void},Ptr{Void},Ptr{Void},Ptr{Void},
-                  Ptr{Void},Ptr{Void},Ptr{Void},
-                  Ptr{Void},Ptr{Void},
+                 (cudnnHandle_t,cudnnBatchNormMode_t,
+                  Ptr{T}, Ptr{T},
+                  Ptr{Void}, Ptr{T},
+                  Ptr{Void}, Ptr{T},
+                  Ptr{Void}, Ptr{T}, Ptr{T},
+                  Ptr{T}, Ptr{T},
                   Cdouble),
                  libcudnn_handle[], BATCHNORM_SPATIAL,
                  Ref(T(alpha)), Ref(T(beta)),
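
In inference mode the call normalizes with the stored running statistics rather than batch statistics. For one element of a channel, with that channel's scale γ, shift β, running mean μ and running variance v, it computes the following (a scalar reference, not part of the diff):

bn_infer(x, γ, β, μ, v, ϵ) = γ * (x - μ) / sqrt(v + ϵ) + β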
@@ -107,7 +121,47 @@ function cudnnBatchNormalizationForward(g::CuArray{T}, b::CuArray{T}, x::CuArray
                  running_mean, running_var,
                  eps)
   end
   y
 end
 
+function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
+                          dx::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
+                          running_mean::CuArray{T}, running_var::CuArray{T},
+                          momentum; training = true,
+                          cache = nothing, eps = T(1e-5),
+                          alpha = T(1), beta = T(0),
+                          dalpha = T(1), dbeta = T(0)) where T<:Union{Float32, Float64}
+  if training
+    if cache !== nothing
+      # Reuse the batch statistics saved by the forward pass.
+      mean, ivar = cache.mean, cache.ivar
+      cache_verbose && info("mean and ivar are fetched from the cache")
+    else
+      # Null pointers make cuDNN recompute the statistics itself.
+      mean, ivar = C_NULL, C_NULL
+    end
+
+    @check ccall((:cudnnBatchNormalizationBackward, libcudnn), cudnnStatus_t,
+                 (cudnnHandle_t,cudnnBatchNormMode_t,
+                  Ptr{T}, Ptr{T},
+                  Ptr{T}, Ptr{T},
+                  Ptr{Void}, Ptr{T},
+                  Ptr{Void}, Ptr{T},
+                  Ptr{Void}, Ptr{T},
+                  Ptr{Void}, Ptr{T}, Ptr{T}, Ptr{T},
+                  Cdouble, Ptr{T}, Ptr{T}),
+                 libcudnn_handle[], BATCHNORM_SPATIAL,
+                 Ref(T(alpha)), Ref(T(beta)),
+                 Ref(T(dalpha)), Ref(T(dbeta)),
+                 TensorDesc(x), x,
+                 TensorDesc(dy), dy,
+                 TensorDesc(dx), dx,
+                 TensorDesc(g), g, dg, db,
+                 eps, mean, ivar)
+  else
+    # Inference mode: gradients of the affine map fixed by the running
+    # statistics. Use `.=` so the results land in the output arguments
+    # (plain `=` would only rebind the local names).
+    ivar = 1 ./ sqrt.(running_var .+ eps)
+    dx .= dy .* g .* ivar
+    dg .= sum(dy .* (x .- running_mean) .* ivar, _reddims(dy))
+    db .= sum(dy, _reddims(dy))
+  end
+end
+
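The else branch above implements the inference-mode gradient: with the running statistics fixed, each channel is just an affine map x ↦ γ·(x − μ)·ivar + β, so the chain rule reduces to the three expressions the code writes. A per-channel CPU reference (illustrative only; g, μ, v, ϵ are one channel's scalars, while dy and x collect that channel's elements):

function bn_infer_grads(dy, x, g, μ, v, ϵ)
    ivar = 1 / sqrt(v + ϵ)
    dx = dy .* (g * ivar)             # ∂L/∂x: a constant scale per channel
    dg = sum(dy .* (x .- μ)) * ivar   # ∂L/∂γ: reduce over the channel's elements
    db = sum(dy)                      # ∂L/∂β
    return dx, dg, db
end
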
const RNN_RELU = 0 # Stock RNN with ReLu activation