Add working backward pass
This commit is contained in:
parent
bc47d02b3f
commit
185f34d9fe
@ -111,6 +111,16 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
|
||||
end
|
||||
end
|
||||
|
||||
"""
    cudnnBNBackward(g, b, x, dy, running_mean, running_var, momentum;
                    training = true, cache = nothing, eps = T(1e-5),
                    alpha = T(1), beta = T(0))

Allocating front end for `cudnnBNBackward!`.

A fresh buffer shaped like `x` is created with `similar(x)`, handed to
`cudnnBNBackward!` as the `dx` output, and returned. `g.grad` and `b.grad` are
passed as the `dg` / `db` output buffers of the in-place call, and `data(g)` as
its `g` argument, so `g` and `b` must expose a `.grad` field (and `g` must be
accepted by `data`). `momentum` is converted to `T` before being forwarded;
all keyword arguments are forwarded unchanged.
"""
function cudnnBNBackward(g, b, x::CuArray{T}, dy::CuArray{T}, running_mean::CuArray{T},
                         running_var::CuArray{T}, momentum;
                         training = true, cache = nothing, eps = T(1e-5),
                         alpha = T(1), beta = T(0)) where T<:Union{Float32, Float64}
    # Output buffer for the gradient with respect to x; filled by the in-place call.
    grad_x = similar(x)
    grad_g, param_g, grad_b = g.grad, data(g), b.grad
    cudnnBNBackward!(grad_g, param_g, grad_b, grad_x, x, dy,
                     running_mean, running_var, T(momentum);
                     training = training, cache = cache, eps = eps,
                     alpha = alpha, beta = beta)
    return grad_x
end
|
||||
|
||||
function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
|
||||
dx::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
|
||||
running_mean::CuArray{T}, running_var::CuArray{T},
|
||||
@ -119,14 +129,22 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
|
||||
alpha = T(1), beta = T(0),
|
||||
dalpha = T(1), dbeta = T(0)) where T<:Union{Float32, Float64}
|
||||
if(training)
|
||||
|
||||
xd = TensorDesc(x)
|
||||
dyd = TensorDesc(dy)
|
||||
dxd = TensorDesc(dx)
|
||||
gd = TensorDesc(T, (1,1,length(g),1))
|
||||
if cache !== nothing
|
||||
mean, ivar = cache.mean, cache.ivar
|
||||
cache_verbose && info("mean and ivar are fetched from the cache")
|
||||
info("mean and ivar are fetched from the cache")
|
||||
else
|
||||
mean, ivar = C_NULL, C_NULL
|
||||
end
|
||||
|
||||
if(eps < BATCHNORM_MIN_EPS)
|
||||
warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", BATCHNORM_MIN_EPS)
|
||||
eps = BATCHNORM_MIN_EPS
|
||||
end
|
||||
|
||||
@check ccall((:cudnnBatchNormalizationBackward, libcudnn), cudnnStatus_t,
|
||||
(cudnnHandle_t,cudnnBatchNormMode_t,
|
||||
Ptr{T}, Ptr{T},
|
||||
@ -139,15 +157,15 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
|
||||
libcudnn_handle[], BATCHNORM_SPATIAL,
|
||||
Ref(T(alpha)), Ref(T(beta)),
|
||||
Ref(T(dalpha)), Ref(T(dbeta)),
|
||||
TensorDesc(x), x,
|
||||
TensorDesc(dy), dy,
|
||||
TensorDesc(dx), dx,
|
||||
TensorDesc(g), g, dg, db,
|
||||
xd, x,
|
||||
dyd, dy,
|
||||
dxd, dx,
|
||||
gd, g, dg, db,
|
||||
eps, mean, ivar)
|
||||
else
|
||||
ivar = 1 ./ sqrt.(running_var .+ eps)
|
||||
dx = dy .* g .* ivar
|
||||
dg = sum(dy .* (x .- running_mean) .* ivar, _reddims(dy))
|
||||
db = sum(dy, _reddims(dy))
|
||||
ivar = 1 ./ sqrt.(reshape(running_var, (1, 1, length(running_var), 1)) .+ eps)
|
||||
dx .= dy .* reshape(g, (1, 1, length(g), 1)) .* ivar
|
||||
dg .= squeeze(sum(dy .* (x .- reshape(running_mean, (1, 1, length(running_mean), 1))) .* ivar, _reddims(dy)), (1,2,4))
|
||||
db .= squeeze(sum(dy, _reddims(dy)), (1,2,4))
|
||||
end
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user