Compare commits

4 Commits

| Author | SHA1 | Date |
|--------|------|------|
|  | 234b757da9 |  |
|  | f32c34247a |  |
|  | 16369769ab |  |
|  | d36d7879df |  |

@@ -1,5 +1,6 @@
 module CUDA

+import CUDAdrv: CuPtr, CU_NULL
 using ..CuArrays
 using Pkg.TOML

@@ -17,7 +17,7 @@ function DropoutDesc(ρ::Real; seed::Integer=0)
   @check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),handle(),s)
   states = CuArray{UInt8}(undef, s[]) # TODO: can we drop this when ρ=0?
   desc = DropoutDesc(d[], states)
-  @check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,Ptr{Nothing},Csize_t,Culonglong),
+  @check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,CuPtr{Nothing},Csize_t,Culonglong),
     desc,handle(),ρ,states,length(states),seed)
   finalizer(desc) do x
     @check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)

@@ -79,18 +79,18 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
     mean = zeros(CuArray{T}, dims...)
     ivar = ones(CuArray{T}, dims...)
   else
-    mean = C_NULL
-    ivar = C_NULL
+    mean = CU_NULL
+    ivar = CU_NULL
   end

   @check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t,
               (cudnnHandle_t,cudnnBatchNormMode_t,
                Ptr{T}, Ptr{T},
-               Ptr{Nothing}, Ptr{T},
-               Ptr{Nothing}, Ptr{T},
-               Ptr{Nothing}, Ptr{T}, Ptr{T},
-               Cdouble, Ptr{T}, Ptr{T},
-               Cdouble, Ptr{T}, Ptr{T}),
+               Ptr{Nothing}, CuPtr{T},
+               Ptr{Nothing}, CuPtr{T},
+               Ptr{Nothing}, CuPtr{T}, CuPtr{T},
+               Cdouble, CuPtr{T}, CuPtr{T},
+               Cdouble, CuPtr{T}, CuPtr{T}),
               handle(), BATCHNORM_SPATIAL,
               Ref(T(alpha)), Ref(T(beta)),
               xd, x,

@@ -107,10 +107,10 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
     @check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t,
                 (Ptr{cudnnHandle_t},cudnnBatchNormMode_t,
                  Ptr{T}, Ptr{T},
-                 Ptr{Nothing}, Ptr{T},
-                 Ptr{Nothing}, Ptr{T},
-                 Ptr{Nothing}, Ptr{T}, Ptr{T},
-                 Ptr{T}, Ptr{T},
+                 Ptr{Nothing}, CuPtr{T},
+                 Ptr{Nothing}, CuPtr{T},
+                 Ptr{Nothing}, CuPtr{T}, CuPtr{T},
+                 CuPtr{T}, CuPtr{T},
                  Cdouble),
                 handle(), BATCHNORM_SPATIAL,
                 Ref(T(alpha)), Ref(T(beta)),

@@ -159,7 +159,7 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
     mean, ivar = cache.mean, cache.ivar
     info("mean and ivar are fetched from the cache")
   else
-    mean, ivar = C_NULL, C_NULL
+    mean, ivar = CU_NULL, CU_NULL
   end

   if eps < BATCHNORM_MIN_EPS

@@ -170,11 +170,11 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
               (cudnnHandle_t,cudnnBatchNormMode_t,
                Ptr{T}, Ptr{T},
                Ptr{T}, Ptr{T},
-               Ptr{Nothing}, Ptr{T},
-               Ptr{Nothing}, Ptr{T},
-               Ptr{Nothing}, Ptr{T},
-               Ptr{Nothing}, Ptr{T}, Ptr{T}, Ptr{T},
-               Cdouble, Ptr{T}, Ptr{T}),
+               Ptr{Nothing}, CuPtr{T},
+               Ptr{Nothing}, CuPtr{T},
+               Ptr{Nothing}, CuPtr{T},
+               Ptr{Nothing}, CuPtr{T}, CuPtr{T}, CuPtr{T},
+               Cdouble, CuPtr{T}, CuPtr{T}),
               handle(), BATCHNORM_SPATIAL,
               Ref(T(alpha)), Ref(T(beta)),
               Ref(T(dalpha)), Ref(T(dbeta)),

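Editorial note: the pattern across the cuDNN hunks above is mechanical. Every ccall slot that receives device memory (a CuArray or the device null) changes `Ptr{T}` to `CuPtr{T}`, while descriptor and host arguments stay `Ptr`. A hedged sketch of why the typed device pointer matters, assuming the CUDAdrv/CuArrays API of this era (the `unsafe_convert` method shown is my reading of that API, not part of the diff):

```julia
# Hedged sketch: CuPtr{T} slots make pointer conversion shape-check at the type
# level. A CuArray converts to a device pointer only for CuPtr arguments;
# converting it to a host Ptr is simply not defined, so a mismatched ccall
# signature fails loudly instead of silently passing a bogus host address.
using CUDAdrv, CuArrays   # assumes a working CUDA setup of this vintage

xs = CuArray(zeros(Float32, 4))
dp = Base.unsafe_convert(CuPtr{Float32}, xs)  # device pointer, a CuPtr
# Base.unsafe_convert(Ptr{Float32}, xs)       # would error: no host pointer exists
```
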
@@ -101,18 +101,18 @@ function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd
   if reserve == nothing
     @check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
                  (Ptr{Nothing}, Ptr{Nothing}, Cint,
-                  Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
-                  Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
-                  Ptr{Nothing}, Ptr{T},
-                  Ptr{Nothing}, Csize_t),
+                  Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
+                  Ptr{Nothing}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
+                  Ptr{Nothing}, CuPtr{T},
+                  CuPtr{Nothing}, Csize_t),
                  handle(), rnn, seqlen,
                  xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
                  workspace, length(workspace))
   else
     @check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t,
                  (Ptr{Nothing}, Ptr{Nothing}, Cint,
-                  Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
-                  Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
+                  Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
+                  CuPtr{Nothing}, Csize_t, CuPtr{Nothing}, Csize_t),
                  handle(), rnn, seqlen,
                  xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
                  workspace, length(workspace), reserve, length(reserve))

@@ -121,7 +121,7 @@ end

 xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]

-hDesc(h::Nothing) = C_NULL, C_NULL
+hDesc(h::Nothing) = C_NULL, CU_NULL
 hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
 function hDesc(h::CuArray)
   TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h

@@ -169,10 +169,10 @@ function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho
                               wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
   @check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
                (Ptr{Nothing}, Ptr{Nothing}, Cint,
-                Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
-                Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing},
-                Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
-                Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
+                Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
+                Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing},
+                CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
+                CuPtr{Nothing}, Csize_t, CuPtr{Nothing}, Csize_t),
                handle(), rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
                wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
 end

@@ -199,12 +199,12 @@ function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, d
                                  workspace, reserve) where T
   @check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t,
                (Ptr{Nothing}, Ptr{Nothing}, Cint,  # handle, rnnDesc, seqLength
-                Ptr{Ptr{Nothing}}, Ptr{T}, #x
-                Ptr{Nothing}, Ptr{T}, #hx
-                Ptr{Ptr{Nothing}}, Ptr{T}, #y
-                Ptr{Nothing}, Csize_t, #ws
-                Ptr{Nothing}, Ptr{T}, #dw
-                Ptr{Nothing}, Csize_t), #rs
+                Ptr{Ptr{Nothing}}, CuPtr{T}, #x
+                Ptr{Nothing}, CuPtr{T}, #hx
+                Ptr{Ptr{Nothing}}, CuPtr{T}, #y
+                CuPtr{Nothing}, Csize_t, #ws
+                Ptr{Nothing}, CuPtr{T}, #dw
+                CuPtr{Nothing}, Csize_t), #rs
                handle(), rnn, seqlen, xd, x, hd, h, yd, y,
                workspace, length(workspace), dwd, dw, reserve, length(reserve))
 end

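Editorial note: `hDesc(h::Nothing) = C_NULL, CU_NULL` shows the host/device split in one line. The first element fills a host-side descriptor slot (`Ptr{Nothing}`), so it stays `C_NULL`; the second fills a device-data slot (`CuPtr{T}`), so it becomes `CU_NULL`. A quick standalone check of the two null types (hedged: assumes CUDAdrv of this era exports `CU_NULL` as the device null):

```julia
using CUDAdrv: CuPtr, CU_NULL

C_NULL isa Ptr      # true: the host null pointer from Base
CU_NULL isa CuPtr   # true: the device null, only convertible into CuPtr slots
```
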
@@ -1,5 +1,5 @@
-init_grad(x) = zero(x)
-zero_grad!(x) = zero(x)
+init_grad(x) = zero_or_nothing(x)
+zero_grad!(x) = zero_or_nothing(x)
 zero_grad!(x::AbstractArray) = (x .= 0)

 scan(c::Call) = foreach(scan, c.args)

@@ -32,6 +32,13 @@ back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
 accum!(x, Δ) = x .+ Δ
 accum!(x::AbstractArray, Δ) = (x .+= Δ)

+# This is correct, because the `.grad` field of tracked is constrained
+# by a type parameter constructed from the (concrete) forward value and
+# thus for `nothing` to show up here, we're guaranteed that the value
+# was `nothing` during the forward pass (and thus we don't care about its
+# derivatives).
+accum!(x::Nothing, Δ) = x
+
 function back(x::Tracked, Δ, once)
   x.isleaf && (x.grad = accum!(x.grad, Δ); return)
   ref = x.ref -= 1

@@ -126,7 +133,7 @@ accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ

 function back_(g::Grads, c::Call, Δ)
   Δs = c.func(Δ)
-  (Δs isa Tuple && length(Δs) >= length(c.args)) ||
+  ((Δs isa Tuple || Δs isa TrackedTuple) && length(Δs) >= length(c.args)) ||
     error("Gradient is not a tuple of length $(length(c.args))")
   foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
 end

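Editorial note: the `nothing`-gradient plumbing above is self-contained enough to try in plain Julia. The snippet below copies the `accum!` methods from this hunk and the `zero_or_nothing` pair from the tuples hunk further down (no Tracker required) to show the intended behavior:

```julia
# Copied from the hunks in this compare; standalone, no Tracker required.
accum!(x, Δ) = x .+ Δ
accum!(x::AbstractArray, Δ) = (x .+= Δ)
accum!(x::Nothing, Δ) = x          # forward value was `nothing`: drop the gradient

zero_or_nothing(x) = zero(x)       # initial gradient for a normal value
zero_or_nothing(x::Nothing) = nothing

accum!([1.0, 2.0], [10.0, 20.0])   # -> [11.0, 22.0], in place
accum!(nothing, [10.0, 20.0])      # -> nothing, a no-op
zero_or_nothing(3.0)               # -> 0.0
zero_or_nothing(nothing)           # -> nothing
```
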
@@ -95,12 +95,27 @@ end

 Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)

-@grad function getindex(xs::AbstractArray, i...)
-  data(xs)[i...], function (Δ)
-    Δ′ = zero(xs)
-    Δ′[i...] = data(Δ)
-    (nobacksies(:getindex, Δ′), map(_->nothing, i)...)
-  end
-end
+struct ∇getindex{T,S}
+  xs::T
+  i::S
+end
+
+function (g::∇getindex)(Δ)
+  Δ′ = zero(g.xs)
+  Δ′[g.i...] = Δ
+  (Δ′, map(_->nothing, g.i)...)
+end
+(g::∇getindex)(Δ::TrackedArray) = track(g, Δ)
+
+@grad function (g::∇getindex)(Δ)
+  z, back = g(data(Δ)), function(Δ′′)
+    (Δ′′[1][g.i...],)
+  end
+  z, back
+end
+
+@grad function getindex(xs::AbstractArray, i...)
+  data(xs)[i...], ∇getindex(xs, i)
+end

 Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...)

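Editorial note: the point of replacing the anonymous pullback with the callable `∇getindex` struct is that a closure has an opaque type, so the backward pass itself cannot be tracked; a struct with a `@grad`-annotated call method can, which is what makes the nested-AD test at the bottom of this compare work. Its forward behavior in isolation (standalone copy, no Tracker):

```julia
# Standalone copy of the struct and its plain (untracked) call method.
struct ∇getindex{T,S}
  xs::T
  i::S
end

function (g::∇getindex)(Δ)
  Δ′ = zero(g.xs)                  # gradient w.r.t. the indexed array
  Δ′[g.i...] = Δ                   # scatter the incoming gradient to the indexed slot
  (Δ′, map(_ -> nothing, g.i)...)  # no gradient w.r.t. the indices
end

g = ∇getindex([1.0, 2.0, 3.0], (2,))
g(5.0)   # -> ([0.0, 5.0, 0.0], nothing)
```
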
@@ -361,6 +376,9 @@ x::TrackedVector * y::TrackedVector = track(*, x, y)
 @grad a::AbstractMatrix * b::AbstractVecOrMat =
   data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ)

+@grad a::TrackedMatrix{<:Any, <:Transpose} * b::AbstractVecOrMat =
+  data(a)*data(b), Δ -> (transpose(b * transpose(Δ)), transpose(a) * Δ)
+
 # NNlib

 using NNlib

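Editorial note: the specialized adjoint for a transposed left operand is mathematically identical to the generic `Δ * transpose(b)`; writing it as `transpose(b * transpose(Δ))` just builds the result as a lazy `Transpose`, matching the wrapper type of the forward value so the structure survives further tracking. A quick numeric check (plain Julia; shapes are arbitrary, chosen for illustration):

```julia
using LinearAlgebra

b = rand(3, 2)                      # right operand
Δ = rand(4, 2)                      # upstream gradient for a 4×2 product a*b
da1 = transpose(b * transpose(Δ))   # specialized form: a lazy Transpose
da2 = Δ * transpose(b)              # generic form: a dense Matrix

da1 ≈ da2                           # true: same values, different wrapper type
```
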
@@ -422,12 +440,17 @@ end

 using ForwardDiff: Dual, partials, value

-trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
+trim(x, Δ) = ndims(Δ) == ndims(x) ? Δ : reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))

-unbroadcast(x::AbstractArray, Δ) =
-  size(x) == size(Δ) ? Δ :
-  length(x) == length(Δ) ? trim(x, Δ) :
-    trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
+function unbroadcast(x::AbstractArray, Δ)
+  if size(x) == size(Δ)
+    return Δ
+  elseif length(x) == length(Δ)
+    return trim(x, Δ)
+  else
+    return trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
+  end
+end

 unbroadcast(x::Number, Δ) = sum(Δ)
 unbroadcast(x::Base.RefValue, _) = nothing

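Editorial note: `unbroadcast` reduces a broadcast gradient back to an argument's shape: pass through if shapes already match, reshape if only the lengths match, otherwise sum over the dimensions that broadcasting expanded. A standalone worked example (both definitions copied from the hunk above):

```julia
# Copied from the hunk above; standalone.
trim(x, Δ) = ndims(Δ) == ndims(x) ? Δ : reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))

function unbroadcast(x::AbstractArray, Δ)
  if size(x) == size(Δ)
    return Δ
  elseif length(x) == length(Δ)
    return trim(x, Δ)
  else
    return trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
  end
end

x = ones(1, 3)                  # broadcast up against a 2×3 array
Δ = [1.0 2.0 3.0; 4.0 5.0 6.0]  # gradient at the broadcast (2×3) shape
unbroadcast(x, Δ)               # -> [5.0 7.0 9.0]: summed over the expanded dim 1
```
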
@@ -122,7 +122,9 @@ accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
 init_grad(x::Tuple) = init_grad.(x)
 zero_grad!(x::Tuple) = zero_grad!.(x)

-track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero.(xs)))
+zero_or_nothing(x) = zero(x)
+zero_or_nothing(x::Nothing) = nothing
+track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero_or_nothing.(xs)))

 function Base.show(io::IO, xs::TrackedTuple)
   show(io, data(xs))

@@ -132,11 +134,18 @@ end
 Base.length(x::TrackedTuple) = length(data(x))

 Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
+Base.iterate(xs::TrackedTuple) = track(iterate, xs)
+Base.iterate(xs::TrackedTuple, i::Integer) = track(iterate, xs, i)
+Base.iterate(xs::TrackedTuple, i::TrackedReal) = iterate(xs, data(i))

 @grad function getindex(xs::TrackedTuple, i)
   data(xs)[i], Δ -> (ntuple(j -> i == j ? Δ : 0, length(xs)), nothing)
 end

+@grad function iterate(xs::TrackedTuple, i=1)
+  (data(xs)[i], i+1), Δ -> (ntuple(j -> i == j ? Δ[1] : 0, length(xs)), nothing)
+end
+
 # Array collection

 function collect(xs)

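Editorial note: with `iterate` defined, a `TrackedTuple` can be destructured (`a, b = t`) or looped over without losing gradients. The pullback routes the incoming gradient for the yielded element into a one-hot tuple, mirroring the `getindex` adjoint; its core in isolation:

```julia
# Plain-Julia check of the one-hot routing in the iterate pullback above.
xs = (1.0, 2.0, 3.0)
i  = 2
Δ  = (10.0, nothing)  # gradient w.r.t. (element, next state) from `iterate`
ntuple(j -> i == j ? Δ[1] : 0, length(xs))  # -> (0, 10.0, 0)
```
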
@@ -5,7 +5,9 @@ children(x) = ()
 mapchildren(f, x) = x

 children(x::Tuple) = x
+children(x::NamedTuple) = x
 mapchildren(f, x::Tuple) = map(f, x)
+mapchildren(f, x::NamedTuple) = map(f, x)

 function treelike(m::Module, T, fs = fieldnames(T))
   @eval m begin

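Editorial note: extending `children`/`mapchildren` to `NamedTuple` lets parameter containers with named fields participate in tree traversal just like plain tuples. A standalone illustration of the two new methods:

```julia
# The two new methods, copied out of the diff, plus the generic fallbacks.
children(x) = ()
mapchildren(f, x) = x
children(x::NamedTuple) = x
mapchildren(f, x::NamedTuple) = map(f, x)

nt = (W = ones(2, 2), b = zeros(2))
mapchildren(x -> 2 .* x, nt)   # -> (W = [2.0 2.0; 2.0 2.0], b = [0.0, 0.0])
```
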
@@ -5,6 +5,7 @@ using NNlib: conv, ∇conv_data, depthwiseconv
 using Printf: @sprintf
 using LinearAlgebra: diagm, dot, LowerTriangular, norm
 using Statistics: mean, std
+using ForwardDiff
 using Random
 # using StatsBase

@@ -36,6 +37,25 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)

 @testset "indexing & slicing" begin
   gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
+
+  # Nested AD for getindex
+  grad_tracker = gradient([1.0, 2.0, 3.0]; nest=true) do x
+    sum(gradient(x; nest=true) do x
+      sum(gradient(x; nest=true) do x
+        sum(x[1:2])^4
+      end[1])
+    end[1])
+  end[1]
+  # We compare to ForwardDiff, since the high order derivative is not
+  # numerically stable under finite differencing.
+  grad_forward = ForwardDiff.gradient([1.0, 2.0, 3.0]) do x
+    sum(ForwardDiff.gradient(x) do x
+      sum(ForwardDiff.gradient(x) do x
+        sum(x[1:2])^4
+      end)
+    end)
+  end
+  @test grad_tracker ≈ grad_forward ≈ [288.0, 288.0, 0.0]
 end

 function promotiontest(f, A, B, C)

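Editorial note on why `[288.0, 288.0, 0.0]`: writing s = x[1] + x[2], the innermost objective is s⁴, each `sum ∘ gradient` layer differentiates once with respect to s, and x[3] never participates. A no-AD sanity check of the chain:

```julia
# d/ds s^4 = 4s^3; summing the two active entries gives 8s^3.
# Next layer: d/ds 8s^3 = 24s^2, summed -> 48s^2. Last: d/ds 48s^2 = 96s.
s = 1.0 + 2.0   # x[1] + x[2]
96s             # -> 288.0, the value in the first two gradient entries
```
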
@@ -336,4 +356,14 @@ end
   @test back([1, 1]) == (32,)
 end

+@testset "transpose" begin
+  let f = (x,a,b)->(x = transpose(x); x * a + x * b),
+      g = x->(a = transpose(x); b = transpose(a); b * [1.0 1.0; 2.0 3.0] + a * [1.0 1.0; 2.0 3.0])
+    @test gradient(x->sum(f(x, [1.0; 1.0], [1.0; 1.0])), [1.0 1.0; 1.0 1.0])[1] ==
+      [2.0 2.0; 2.0 2.0]
+    @test gradient(x->sum(g(x)), [1.0 1.0; 1.0 1.0])[1] ==
+      [4.0 7.0; 7.0 10.0]
+  end
+end
+
 end #testset

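Editorial note: the expected matrices follow from linearity. In `g`, `b = transpose(transpose(x))` is just `x`, so the objective is `sum(x * M) + sum(transpose(x) * M)` with `M = [1.0 1.0; 2.0 3.0]`; the gradient of the first term is `ones(2,2) * transpose(M)` and the second term contributes the transpose of that. A quick check:

```julia
M = [1.0 1.0; 2.0 3.0]
G = ones(2, 2) * transpose(M)   # gradient of sum(x * M) w.r.t. x -> [2 5; 2 5]
G + transpose(G)                # -> [4.0 7.0; 7.0 10.0], the expected value
```
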