Add working backward pass

This commit is contained in:
Avik Pal 2018-06-20 12:09:54 +05:30
parent bc47d02b3f
commit 185f34d9fe

View File

@ -111,6 +111,16 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
end
end
"""
    cudnnBNBackward(g, b, x, dy, running_mean, running_var, momentum;
                    training = true, cache = nothing, eps = T(1e-5),
                    alpha = T(1), beta = T(0))

Allocating front end for `cudnnBNBackward!`.

A fresh buffer shaped like `x` is created to receive the gradient with respect
to the input. The in-place routine is then invoked with `g.grad`, `data(g)` and
`b.grad`, so the gradient buffers attached to `g` and `b` are handed to it as
outputs. `momentum` is converted to the element type `T` before being passed on;
every keyword argument is forwarded unchanged.

NOTE(review): `g` and `b` are untyped here; the body only requires that `g`
supports `.grad` and `data(g)` and that `b` supports `.grad` (presumably
tracked scale/bias parameters -- confirm against the callers).

Returns the newly allocated input-gradient array.
"""
function cudnnBNBackward(g, b, x::CuArray{T}, dy::CuArray{T}, running_mean::CuArray{T},
                         running_var::CuArray{T}, momentum;
                         training = true, cache = nothing, eps = T(1e-5),
                         alpha = T(1), beta = T(0)) where T<:Union{Float32, Float64}
    # Destination for the input gradient; same type and size as `x`.
    grad_x = similar(x)
    cudnnBNBackward!(g.grad, data(g), b.grad, grad_x, x, dy,
                     running_mean, running_var, T(momentum),
                     training = training, cache = cache, eps = eps,
                     alpha = alpha, beta = beta)
    return grad_x
end
function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
dx::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
running_mean::CuArray{T}, running_var::CuArray{T},
@ -119,14 +129,22 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
alpha = T(1), beta = T(0),
dalpha = T(1), dbeta = T(0)) where T<:Union{Float32, Float64}
if(training)
xd = TensorDesc(x)
dyd = TensorDesc(dy)
dxd = TensorDesc(dx)
gd = TensorDesc(T, (1,1,length(g),1))
if cache !== nothing
mean, ivar = cache.mean, cache.ivar
cache_verbose && info("mean and ivar are fetched from the cache")
info("mean and ivar are fetched from the cache")
else
mean, ivar = C_NULL, C_NULL
end
if(eps < BATCHNORM_MIN_EPS)
warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", BATCHNORM_MIN_EPS)
eps = BATCHNORM_MIN_EPS
end
@check ccall((:cudnnBatchNormalizationBackward, libcudnn), cudnnStatus_t,
(cudnnHandle_t,cudnnBatchNormMode_t,
Ptr{T}, Ptr{T},
@ -139,15 +157,15 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
libcudnn_handle[], BATCHNORM_SPATIAL,
Ref(T(alpha)), Ref(T(beta)),
Ref(T(dalpha)), Ref(T(dbeta)),
TensorDesc(x), x,
TensorDesc(dy), dy,
TensorDesc(dx), dx,
TensorDesc(g), g, dg, db,
xd, x,
dyd, dy,
dxd, dx,
gd, g, dg, db,
eps, mean, ivar)
else
ivar = 1 ./ sqrt.(running_var .+ eps)
dx = dy .* g .* ivar
dg = sum(dy .* (x .- running_mean) .* ivar, _reddims(dy))
db = sum(dy, _reddims(dy))
ivar = 1 ./ sqrt.(reshape(running_var, (1, 1, length(running_var), 1)) .+ eps)
dx .= dy .* reshape(g, (1, 1, length(g), 1)) .* ivar
dg .= squeeze(sum(dy .* (x .- reshape(running_mean, (1, 1, length(running_mean), 1))) .* ivar, _reddims(dy)), (1,2,4))
db .= squeeze(sum(dy, _reddims(dy)), (1,2,4))
end
end