Add working backward pass
This commit is contained in:
parent
bc47d02b3f
commit
185f34d9fe
@ -111,6 +111,16 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function cudnnBNBackward(g, b, x::CuArray{T}, dy::CuArray{T}, running_mean::CuArray{T},
|
||||||
|
running_var::CuArray{T}, momentum;
|
||||||
|
training = true, cache = nothing, eps = T(1e-5),
|
||||||
|
alpha = T(1), beta = T(0)) where T<:Union{Float32, Float64}
|
||||||
|
dx = similar(x)
|
||||||
|
cudnnBNBackward!(g.grad, data(g), b.grad, dx, x, dy, running_mean, running_var, T(momentum),
|
||||||
|
training = training, cache = cache, eps = eps, alpha = alpha, beta = beta)
|
||||||
|
dx
|
||||||
|
end
|
||||||
|
|
||||||
function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
|
function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
|
||||||
dx::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
|
dx::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
|
||||||
running_mean::CuArray{T}, running_var::CuArray{T},
|
running_mean::CuArray{T}, running_var::CuArray{T},
|
||||||
@ -119,14 +129,22 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
|
|||||||
alpha = T(1), beta = T(0),
|
alpha = T(1), beta = T(0),
|
||||||
dalpha = T(1), dbeta = T(0)) where T<:Union{Float32, Float64}
|
dalpha = T(1), dbeta = T(0)) where T<:Union{Float32, Float64}
|
||||||
if(training)
|
if(training)
|
||||||
|
xd = TensorDesc(x)
|
||||||
|
dyd = TensorDesc(dy)
|
||||||
|
dxd = TensorDesc(dx)
|
||||||
|
gd = TensorDesc(T, (1,1,length(g),1))
|
||||||
if cache !== nothing
|
if cache !== nothing
|
||||||
mean, ivar = cache.mean, cache.ivar
|
mean, ivar = cache.mean, cache.ivar
|
||||||
cache_verbose && info("mean and ivar are fetched from the cache")
|
info("mean and ivar are fetched from the cache")
|
||||||
else
|
else
|
||||||
mean, ivar = C_NULL, C_NULL
|
mean, ivar = C_NULL, C_NULL
|
||||||
end
|
end
|
||||||
|
|
||||||
|
if(eps < BATCHNORM_MIN_EPS)
|
||||||
|
warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", BATCHNORM_MIN_EPS)
|
||||||
|
eps = BATCHNORM_MIN_EPS
|
||||||
|
end
|
||||||
|
|
||||||
@check ccall((:cudnnBatchNormalizationBackward, libcudnn), cudnnStatus_t,
|
@check ccall((:cudnnBatchNormalizationBackward, libcudnn), cudnnStatus_t,
|
||||||
(cudnnHandle_t,cudnnBatchNormMode_t,
|
(cudnnHandle_t,cudnnBatchNormMode_t,
|
||||||
Ptr{T}, Ptr{T},
|
Ptr{T}, Ptr{T},
|
||||||
@ -139,15 +157,15 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
|
|||||||
libcudnn_handle[], BATCHNORM_SPATIAL,
|
libcudnn_handle[], BATCHNORM_SPATIAL,
|
||||||
Ref(T(alpha)), Ref(T(beta)),
|
Ref(T(alpha)), Ref(T(beta)),
|
||||||
Ref(T(dalpha)), Ref(T(dbeta)),
|
Ref(T(dalpha)), Ref(T(dbeta)),
|
||||||
TensorDesc(x), x,
|
xd, x,
|
||||||
TensorDesc(dy), dy,
|
dyd, dy,
|
||||||
TensorDesc(dx), dx,
|
dxd, dx,
|
||||||
TensorDesc(g), g, dg, db,
|
gd, g, dg, db,
|
||||||
eps, mean, ivar)
|
eps, mean, ivar)
|
||||||
else
|
else
|
||||||
ivar = 1 ./ sqrt.(running_var .+ eps)
|
ivar = 1 ./ sqrt.(reshape(running_var, (1, 1, length(running_var), 1)) .+ eps)
|
||||||
dx = dy .* g .* ivar
|
dx .= dy .* reshape(g, (1, 1, length(g), 1)) .* ivar
|
||||||
dg = sum(dy .* (x .- running_mean) .* ivar, _reddims(dy))
|
dg .= squeeze(sum(dy .* (x .- reshape(running_mean, (1, 1, length(running_mean), 1))) .* ivar, _reddims(dy)), (1,2,4))
|
||||||
db = sum(dy, _reddims(dy))
|
db .= squeeze(sum(dy, _reddims(dy)), (1,2,4))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user