Compare commits


4 Commits

Author SHA1 Message Date
Keno Fischer 234b757da9 Treat NamedTuple like Tuple for treelike purposes 2019-02-06 11:08:06 -05:00
Tim Besard f32c34247a Adapt to the new CUDAdrv.CuPtr pointer type. 2019-02-06 10:42:31 -05:00
Keno Fischer 16369769ab Handle various cases of multiplying transpose-wrapped matrices
See test cases. I hit these while taking third-order derivatives of
matrix multiplies (whose gradient definitions use transpose).
2019-02-06 10:41:32 -05:00
Keno Fischer d36d7879df Add higher order autodiff for getindex
This fixes higher-order autodiff for getindex (and by extension vcat, since
its backward pass uses getindex). This makes Tracker able to handle
the third-order derivatives needed to implement [1].

[1] Physics-Informed Generative Adversarial Networks for Stochastic Differential Equations. https://arxiv.org/abs/1811.02033
2019-02-06 10:41:23 -05:00
8 changed files with 121 additions and 49 deletions

View File

@@ -1,5 +1,6 @@
module CUDA
import CUDAdrv: CuPtr, CU_NULL
using ..CuArrays
using Pkg.TOML
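
The change above (and the cudnn wrappers below) swaps Ptr{T} for CUDAdrv's CuPtr{T} wherever a ccall argument is actually a device buffer. The following standalone sketch uses made-up FakeHostPtr/FakeDevicePtr types, not the real CUDAdrv API, to show why keeping host and device pointers as distinct types helps: mixing them up becomes a method error instead of silently passing a device address where a host address is expected.

struct FakeDevicePtr{T}   # stand-in for a device-pointer type such as CuPtr{T}
    addr::UInt
end
struct FakeHostPtr{T}     # stand-in for an ordinary host Ptr{T}
    addr::UInt
end

# A "binding" that declares its second argument as a device pointer.
copy_from_device(dst::FakeHostPtr{T}, src::FakeDevicePtr{T}, n::Integer) where {T} =
    "copied $n × $T from device address 0x$(string(src.addr, base = 16))"

h = FakeHostPtr{Float32}(0x1000)
d = FakeDevicePtr{Float32}(0x7f00)
copy_from_device(h, d, 4)      # ok: argument kinds match the declared signature
# copy_from_device(d, h, 4)    # MethodError: host and device pointers swapped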

View File

@@ -17,7 +17,7 @@ function DropoutDesc(ρ::Real; seed::Integer=0)
@check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),handle(),s)
states = CuArray{UInt8}(undef, s[]) # TODO: can we drop this when ρ=0?
desc = DropoutDesc(d[], states)
@check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,Ptr{Nothing},Csize_t,Culonglong),
@check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,CuPtr{Nothing},Csize_t,Culonglong),
desc,handle(),ρ,states,length(states),seed)
finalizer(desc) do x
@check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
@@ -79,18 +79,18 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
mean = zeros(CuArray{T}, dims...)
ivar = ones(CuArray{T}, dims...)
else
mean = C_NULL
ivar = C_NULL
mean = CU_NULL
ivar = CU_NULL
end
@check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t,
(cudnnHandle_t,cudnnBatchNormMode_t,
Ptr{T}, Ptr{T},
Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T}, Ptr{T},
Cdouble, Ptr{T}, Ptr{T},
Cdouble, Ptr{T}, Ptr{T}),
Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, CuPtr{T}, CuPtr{T},
Cdouble, CuPtr{T}, CuPtr{T},
Cdouble, CuPtr{T}, CuPtr{T}),
handle(), BATCHNORM_SPATIAL,
Ref(T(alpha)), Ref(T(beta)),
xd, x,
@@ -107,10 +107,10 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
@check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t,
(Ptr{cudnnHandle_t},cudnnBatchNormMode_t,
Ptr{T}, Ptr{T},
Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T}, Ptr{T},
Ptr{T}, Ptr{T},
Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, CuPtr{T}, CuPtr{T},
CuPtr{T}, CuPtr{T},
Cdouble),
handle(), BATCHNORM_SPATIAL,
Ref(T(alpha)), Ref(T(beta)),
@@ -159,7 +159,7 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
mean, ivar = cache.mean, cache.ivar
info("mean and ivar are fetched from the cache")
else
mean, ivar = C_NULL, C_NULL
mean, ivar = CU_NULL, CU_NULL
end
if eps < BATCHNORM_MIN_EPS
@@ -170,11 +170,11 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
(cudnnHandle_t,cudnnBatchNormMode_t,
Ptr{T}, Ptr{T},
Ptr{T}, Ptr{T},
Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T}, Ptr{T}, Ptr{T},
Cdouble, Ptr{T}, Ptr{T}),
Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, CuPtr{T}, CuPtr{T}, CuPtr{T},
Cdouble, CuPtr{T}, CuPtr{T}),
handle(), BATCHNORM_SPATIAL,
Ref(T(alpha)), Ref(T(beta)),
Ref(T(dalpha)), Ref(T(dbeta)),

View File

@@ -101,18 +101,18 @@ function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd
if reserve == nothing
@check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
(Ptr{Nothing}, Ptr{Nothing}, Cint,
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Csize_t),
Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, CuPtr{T},
CuPtr{Nothing}, Csize_t),
handle(), rnn, seqlen,
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
workspace, length(workspace))
else
@check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t,
(Ptr{Nothing}, Ptr{Nothing}, Cint,
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
CuPtr{Nothing}, Csize_t, CuPtr{Nothing}, Csize_t),
handle(), rnn, seqlen,
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
workspace, length(workspace), reserve, length(reserve))
@@ -121,7 +121,7 @@ end
xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]
hDesc(h::Nothing) = C_NULL, C_NULL
hDesc(h::Nothing) = C_NULL, CU_NULL
hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
function hDesc(h::CuArray)
TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
@@ -169,10 +169,10 @@ function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
@check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
(Ptr{Nothing}, Ptr{Nothing}, Cint,
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing},
Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing},
CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
CuPtr{Nothing}, Csize_t, CuPtr{Nothing}, Csize_t),
handle(), rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
end
@@ -199,12 +199,12 @@ function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, d
workspace, reserve) where T
@check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t,
(Ptr{Nothing}, Ptr{Nothing}, Cint, # handle, rnnDesc, seqLength
Ptr{Ptr{Nothing}}, Ptr{T}, #x
Ptr{Nothing}, Ptr{T}, #hx
Ptr{Ptr{Nothing}}, Ptr{T}, #y
Ptr{Nothing}, Csize_t, #ws
Ptr{Nothing}, Ptr{T}, #dw
Ptr{Nothing}, Csize_t), #rs
Ptr{Ptr{Nothing}}, CuPtr{T}, #x
Ptr{Nothing}, CuPtr{T}, #hx
Ptr{Ptr{Nothing}}, CuPtr{T}, #y
CuPtr{Nothing}, Csize_t, #ws
Ptr{Nothing}, CuPtr{T}, #dw
CuPtr{Nothing}, Csize_t), #rs
handle(), rnn, seqlen, xd, x, hd, h, yd, y,
workspace, length(workspace), dwd, dw, reserve, length(reserve))
end

View File

@@ -1,5 +1,5 @@
init_grad(x) = zero(x)
zero_grad!(x) = zero(x)
init_grad(x) = zero_or_nothing(x)
zero_grad!(x) = zero_or_nothing(x)
zero_grad!(x::AbstractArray) = (x .= 0)
scan(c::Call) = foreach(scan, c.args)
@@ -32,6 +32,13 @@ back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
accum!(x, Δ) = x .+ Δ
accum!(x::AbstractArray, Δ) = (x .+= Δ)
# This is correct, because the `.grad` field of tracked is constrained
# by a type parameter constructed from the (concrete) forward value and
# thus for `nothing` to show up here, we're guaranteed that the value
# was `nothing` during the forward pass (and thus we don't care about its
# derivatives).
accum!(x::Nothing, Δ) = x
function back(x::Tracked, Δ, once)
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
ref = x.ref -= 1
@@ -126,7 +133,7 @@ accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
function back_(g::Grads, c::Call, Δ)
Δs = c.func(Δ)
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
((Δs isa Tuple || Δs isa TrackedTuple) && length(Δs) >= length(c.args)) ||
error("Gradient is not a tuple of length $(length(c.args))")
foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
end
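
A standalone sketch of the accumulation rule added above (simplified names, not the Tracker internals): when the forward value was nothing, its gradient slot stays nothing and incoming cotangents are dropped rather than raising an error.

accum_demo(x, Δ) = x .+ Δ
accum_demo(x::AbstractArray, Δ) = (x .+= Δ)
accum_demo(x::Nothing, Δ) = x        # mirrors accum!(x::Nothing, Δ) = x

accum_demo([1.0, 2.0], [0.5, 0.5])   # [1.5, 2.5]
accum_demo(nothing, [0.5, 0.5])      # nothing: the forward value carried no gradient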

View File

@@ -95,12 +95,27 @@ end
Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
@grad function getindex(xs::AbstractArray, i...)
data(xs)[i...], function (Δ)
Δ′ = zero(xs)
Δ′[i...] = data(Δ)
(nobacksies(:getindex, Δ′), map(_->nothing, i)...)
struct ∇getindex{T,S}
xs::T
i::S
end
function (g::∇getindex)(Δ)
Δ′ = zero(g.xs)
Δ′[g.i...] = Δ
(Δ′, map(_->nothing, g.i)...)
end
(g::∇getindex)(Δ::TrackedArray) = track(g, Δ)
@grad function (g::∇getindex)(Δ)
z, back = g(data(Δ)), function(Δ′′)
(Δ′′[1][g.i...],)
end
z, back
end
@grad function getindex(xs::AbstractArray, i...)
data(xs)[i...], ∇getindex(xs, i)
end
Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...)
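
The key idea in the hunk above is to package the getindex pullback as a callable struct, so the backward pass is itself an operation Tracker can track and differentiate (via the (g::∇getindex)(Δ::TrackedArray) method), which is what second- and third-order getindex derivatives need. A minimal standalone sketch of the pattern, with plain arrays and a hypothetical name, no tracking:

struct PullbackGetindex{T,S}
    xs::T      # the array that was indexed in the forward pass
    i::S       # the index tuple
end

# Scatter the incoming cotangent back into a zero array of the original shape.
function (g::PullbackGetindex)(Δ)
    Δ′ = zero(g.xs)
    Δ′[g.i...] = Δ
    (Δ′, map(_ -> nothing, g.i)...)
end

g = PullbackGetindex([1.0, 2.0, 3.0], (2,))
g(5.0)   # ([0.0, 5.0, 0.0], nothing): Δ lands in slot 2, the index gets no gradient
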
@@ -361,6 +376,9 @@ x::TrackedVector * y::TrackedVector = track(*, x, y)
@grad a::AbstractMatrix * b::AbstractVecOrMat =
data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ)
@grad a::TrackedMatrix{<:Any, <:Transpose} * b::AbstractVecOrMat =
data(a)*data(b), Δ -> (transpose(b * transpose(Δ)), transpose(a) * Δ)
# NNlib
using NNlib
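
A quick check, with plain arrays and no Tracker, of the identity the new transpose rule relies on: writing the cotangent of a as transpose(b * transpose(Δ)) is numerically the familiar Δ * bᵀ, just kept transpose-wrapped (presumably to match the Transpose-wrapped forward value of a).

Δ = [1.0 2.0; 3.0 4.0]
b = [5.0 6.0; 7.0 8.0]
transpose(b * transpose(Δ)) == Δ * transpose(b)   # true: (b Δᵀ)ᵀ == Δ bᵀ
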
@@ -422,12 +440,17 @@ end
using ForwardDiff: Dual, partials, value
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
trim(x, Δ) = ndims(Δ) == ndims(x) ? Δ : reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
unbroadcast(x::AbstractArray, Δ) =
size(x) == size(Δ) ? Δ :
length(x) == length(Δ) ? trim(x, Δ) :
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
function unbroadcast(x::AbstractArray, Δ)
if size(x) == size(Δ)
return Δ
elseif length(x) == length(Δ)
return trim(x, Δ)
else
return trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
end
end
unbroadcast(x::Number, Δ) = sum(Δ)
unbroadcast(x::Base.RefValue, _) = nothing

View File

@@ -122,7 +122,9 @@ accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
init_grad(x::Tuple) = init_grad.(x)
zero_grad!(x::Tuple) = zero_grad!.(x)
track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero.(xs)))
zero_or_nothing(x) = zero(x)
zero_or_nothing(x::Nothing) = nothing
track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero_or_nothing.(xs)))
function Base.show(io::IO, xs::TrackedTuple)
show(io, data(xs))
@@ -132,11 +134,18 @@ end
Base.length(x::TrackedTuple) = length(data(x))
Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
Base.iterate(xs::TrackedTuple) = track(iterate, xs)
Base.iterate(xs::TrackedTuple, i::Integer) = track(iterate, xs, i)
Base.iterate(xs::TrackedTuple, i::TrackedReal) = iterate(xs, data(i))
@grad function getindex(xs::TrackedTuple, i)
data(xs)[i], Δ -> (ntuple(j -> i == j ? Δ : 0, length(xs)), nothing)
end
@grad function iterate(xs::TrackedTuple, i=1)
(data(xs)[i], i+1), Δ -> (ntuple(j -> i == j ? Δ[1] : 0, length(xs)), nothing)
end
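
A standalone sketch of the iterate gradient just above (simplified, without the Tracker types): iterating yields (element, next state), so the pullback routes Δ[1] into the element's slot of the tuple cotangent and returns nothing for the state argument.

iterate_pullback(xs::Tuple, i, Δ) =
    (ntuple(j -> i == j ? Δ[1] : 0, length(xs)), nothing)

iterate_pullback((10.0, 20.0, 30.0), 2, (1.0, nothing))   # ((0, 1.0, 0), nothing)
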
# Array collection
function collect(xs)

View File

@@ -5,7 +5,9 @@ children(x) = ()
mapchildren(f, x) = x
children(x::Tuple) = x
children(x::NamedTuple) = x
mapchildren(f, x::Tuple) = map(f, x)
mapchildren(f, x::NamedTuple) = map(f, x)
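
A small standalone illustration of what the two NamedTuple methods above buy, using hypothetical *_demo names that mirror children/mapchildren: a NamedTuple is traversed exactly like a Tuple, with map preserving the field names.

children_demo(x::Tuple) = x
children_demo(x::NamedTuple) = x
mapchildren_demo(f, x::Tuple) = map(f, x)
mapchildren_demo(f, x::NamedTuple) = map(f, x)

nt = (weight = [1.0 2.0], bias = [0.5])
mapchildren_demo(x -> 2 .* x, nt)   # (weight = [2.0 4.0], bias = [1.0])
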
function treelike(m::Module, T, fs = fieldnames(T))
@eval m begin

View File

@@ -5,6 +5,7 @@ using NNlib: conv, ∇conv_data, depthwiseconv
using Printf: @sprintf
using LinearAlgebra: diagm, dot, LowerTriangular, norm
using Statistics: mean, std
using ForwardDiff
using Random
# using StatsBase
@@ -36,6 +37,25 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
@testset "indexing & slicing" begin
gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
# Nested AD for getindex
grad_tracker = gradient([1.0, 2.0, 3.0]; nest=true) do x
sum(gradient(x; nest=true) do x
sum(gradient(x; nest=true) do x
sum(x[1:2])^4
end[1])
end[1])
end[1]
# We compare to ForwardDiff, since the high order derivative is not
# numerically stable under finite differencing.
grad_forward = ForwardDiff.gradient([1.0, 2.0, 3.0]) do x
sum(ForwardDiff.gradient(x) do x
sum(ForwardDiff.gradient(x) do x
sum(x[1:2])^4
end)
end)
end
@test grad_tracker ≈ grad_forward ≈ [288.0, 288.0, 0.0]
end
function promotiontest(f, A, B, C)
@@ -336,4 +356,14 @@ end
@test back([1, 1]) == (32,)
end
@testset "transpose" begin
let f = (x,a,b)->(x = transpose(x); x * a + x * b),
g = x->(a = transpose(x); b = transpose(a); b * [1.0 1.0; 2.0 3.0] + a * [1.0 1.0; 2.0 3.0])
@test gradient(x->sum(f(x, [1.0; 1.0], [1.0; 1.0])), [1.0 1.0; 1.0 1.0])[1] ==
[2.0 2.0; 2.0 2.0]
@test gradient(x->sum(g(x)), [1.0 1.0; 1.0 1.0])[1] ==
[4.0 7.0; 7.0 10.0]
end
end
end #testset