Add working backward pass

This commit is contained in:
Avik Pal 2018-06-20 12:09:54 +05:30
parent bc47d02b3f
commit 185f34d9fe

View File

@ -73,20 +73,20 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
end end
@check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t, @check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t,
(cudnnHandle_t,cudnnBatchNormMode_t, (cudnnHandle_t,cudnnBatchNormMode_t,
Ptr{T}, Ptr{T}, Ptr{T}, Ptr{T},
Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
Ptr{Void}, Ptr{T}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{T},
Cdouble, Ptr{T}, Ptr{T}, Cdouble, Ptr{T}, Ptr{T},
Cdouble, Ptr{T}, Ptr{T}), Cdouble, Ptr{T}, Ptr{T}),
libcudnn_handle[], BATCHNORM_SPATIAL, libcudnn_handle[], BATCHNORM_SPATIAL,
Ref(T(alpha)), Ref(T(beta)), Ref(T(alpha)), Ref(T(beta)),
xd, x, xd, x,
yd, y, yd, y,
gd, g, b, gd, g, b,
momentum, running_mean, running_var, momentum, running_mean, running_var,
eps, mean, ivar) eps, mean, ivar)
if(cache !== nothing) if(cache !== nothing)
cache.mean = mean cache.mean = mean
@ -94,60 +94,78 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
end end
else else
@check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t, @check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t,
(Ptr{cudnnHandle_t},cudnnBatchNormMode_t, (Ptr{cudnnHandle_t},cudnnBatchNormMode_t,
Ptr{T}, Ptr{T}, Ptr{T}, Ptr{T},
Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
Ptr{Void}, Ptr{T}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{T},
Ptr{T}, Ptr{T}, Ptr{T}, Ptr{T},
Cdouble), Cdouble),
libcudnn_handle[], BATCHNORM_SPATIAL, libcudnn_handle[], BATCHNORM_SPATIAL,
Ref(T(alpha)), Ref(T(beta)), Ref(T(alpha)), Ref(T(beta)),
xd, x, xd, x,
yd, y, yd, y,
gd, g, b, gd, g, b,
running_mean, running_var, running_mean, running_var,
eps) eps)
end end
end end
function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T}, function cudnnBNBackward(g, b, x::CuArray{T}, dy::CuArray{T}, running_mean::CuArray{T},
dx::CuArray{T}, x::CuArray{T}, dy::CuArray{T}, running_var::CuArray{T}, momentum;
running_mean::CuArray{T}, running_var::CuArray{T}, training = true, cache = nothing, eps = T(1e-5),
momentum; training = true, alpha = T(1), beta = T(0)) where T<:Union{Float32, Float64}
cache = nothing, eps = T(1e-5), dx = similar(x)
alpha = T(1), beta = T(0), cudnnBNBackward!(g.grad, data(g), b.grad, dx, x, dy, running_mean, running_var, T(momentum),
dalpha = T(1), dbeta = T(0)) where T<:Union{Float32, Float64} training = training, cache = cache, eps = eps, alpha = alpha, beta = beta)
if(training) dx
end
if cache !== nothing
mean, ivar = cache.mean, cache.ivar function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
cache_verbose && info("mean and ivar are fetched from the cache") dx::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
else running_mean::CuArray{T}, running_var::CuArray{T},
mean, ivar = C_NULL, C_NULL momentum; training = true,
end cache = nothing, eps = T(1e-5),
alpha = T(1), beta = T(0),
@check ccall((:cudnnBatchNormalizationBackward, libcudnn), cudnnStatus_t, dalpha = T(1), dbeta = T(0)) where T<:Union{Float32, Float64}
(cudnnHandle_t,cudnnBatchNormMode_t, if(training)
Ptr{T}, Ptr{T}, xd = TensorDesc(x)
Ptr{T}, Ptr{T}, dyd = TensorDesc(dy)
Ptr{Void}, Ptr{T}, dxd = TensorDesc(dx)
Ptr{Void}, Ptr{T}, gd = TensorDesc(T, (1,1,length(g),1))
Ptr{Void}, Ptr{T}, if cache !== nothing
Ptr{Void}, Ptr{T}, Ptr{T}, Ptr{T}, mean, ivar = cache.mean, cache.ivar
Cdouble, Ptr{T}, Ptr{T}), info("mean and ivar are fetched from the cache")
libcudnn_handle[], BATCHNORM_SPATIAL, else
Ref(T(alpha)), Ref(T(beta)), mean, ivar = C_NULL, C_NULL
Ref(T(dalpha)), Ref(T(dbeta)), end
TensorDesc(x), x,
TensorDesc(dy), dy, if(eps < BATCHNORM_MIN_EPS)
TensorDesc(dx), dx, warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", BATCHNORM_MIN_EPS)
TensorDesc(g), g, dg, db, eps = BATCHNORM_MIN_EPS
eps, mean, ivar) end
else
ivar = 1 ./ sqrt.(running_var .+ eps) @check ccall((:cudnnBatchNormalizationBackward, libcudnn), cudnnStatus_t,
dx = dy .* g .* ivar (cudnnHandle_t,cudnnBatchNormMode_t,
dg = sum(dy .* (x .- running_mean) .* ivar, _reddims(dy)) Ptr{T}, Ptr{T},
db = sum(dy, _reddims(dy)) Ptr{T}, Ptr{T},
end Ptr{Void}, Ptr{T},
Ptr{Void}, Ptr{T},
Ptr{Void}, Ptr{T},
Ptr{Void}, Ptr{T}, Ptr{T}, Ptr{T},
Cdouble, Ptr{T}, Ptr{T}),
libcudnn_handle[], BATCHNORM_SPATIAL,
Ref(T(alpha)), Ref(T(beta)),
Ref(T(dalpha)), Ref(T(dbeta)),
xd, x,
dyd, dy,
dxd, dx,
gd, g, dg, db,
eps, mean, ivar)
else
ivar = 1 ./ sqrt.(reshape(running_var, (1, 1, length(running_var), 1)) .+ eps)
dx .= dy .* reshape(g, (1, 1, length(g), 1)) .* ivar
dg .= squeeze(sum(dy .* (x .- reshape(running_mean, (1, 1, length(running_mean), 1))) .* ivar, _reddims(dy)), (1,2,4))
db .= squeeze(sum(dy, _reddims(dy)), (1,2,4))
end
end end