Compare commits

...

4 Commits

Author SHA1 Message Date
Keno Fischer
234b757da9 Treat NamedTuple like Tuple for treelike purposes 2019-02-06 11:08:06 -05:00
Tim Besard
f32c34247a Adapt to the new CUDAdrv.CuPtr pointer type. 2019-02-06 10:42:31 -05:00
Keno Fischer
16369769ab Handle various cases of multiplying transpose-wrapped matrices
See test cases. I hit these while taking third-order derivatives of
matrix multiplies (whose gradient definitions use transpose).
2019-02-06 10:41:32 -05:00
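A rough illustration of the case this covers, mirroring the "transpose" testset added at the bottom of this compare (a minimal sketch assuming the in-tree Flux.Tracker of this era; gradient returns a tuple, so [1] picks the matrix gradient):

    using Flux.Tracker: gradient

    # Right-multiplying a transpose-wrapped tracked matrix hits the new
    # TrackedMatrix{<:Any, <:Transpose} * AbstractVecOrMat gradient rule.
    f(x, a, b) = (xt = transpose(x); xt * a + xt * b)
    gradient(x -> sum(f(x, [1.0; 1.0], [1.0; 1.0])), [1.0 1.0; 1.0 1.0])[1]  # == [2.0 2.0; 2.0 2.0]
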
Keno Fischer
d36d7879df Add higher order autodiff for getindex
This fixes higher-order autodiff for getindex (and by extension vcat, since
its backward pass uses getindex). This makes Tracker able to handle
the third-order derivatives needed to implement [1].

[1] Physics-Informed Generative Adversarial Networks for Stochastic Differential Equations. https://arxiv.org/abs/1811.02033
2019-02-06 10:41:23 -05:00
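For context, the nesting pattern this enables looks like the sketch below. It mirrors the new "indexing & slicing" test further down and assumes the in-tree Flux.Tracker gradient with its nest keyword (a minimal sketch, not additional test coverage):

    using Flux.Tracker: gradient

    # Third-order derivative of x -> sum(x[1:2])^4: each nested call
    # differentiates the backward pass of the one inside it, which
    # exercises the tracked pullback for getindex.
    third = gradient([1.0, 2.0, 3.0]; nest = true) do x
      sum(gradient(x; nest = true) do x
        sum(gradient(x; nest = true) do x
          sum(x[1:2])^4
        end[1])
      end[1])
    end[1]
    # expected ≈ [288.0, 288.0, 0.0], matching ForwardDiff on the same function
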
8 changed files with 121 additions and 49 deletions

View File

@@ -1,5 +1,6 @@
 module CUDA
 
+import CUDAdrv: CuPtr, CU_NULL
 using ..CuArrays
 using Pkg.TOML

View File

@@ -17,7 +17,7 @@ function DropoutDesc(ρ::Real; seed::Integer=0)
   @check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),handle(),s)
   states = CuArray{UInt8}(undef, s[]) # TODO: can we drop this when ρ=0?
   desc = DropoutDesc(d[], states)
-  @check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,Ptr{Nothing},Csize_t,Culonglong),
+  @check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,CuPtr{Nothing},Csize_t,Culonglong),
               desc,handle(),ρ,states,length(states),seed)
   finalizer(desc) do x
     @check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
@@ -79,18 +79,18 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
     mean = zeros(CuArray{T}, dims...)
     ivar = ones(CuArray{T}, dims...)
   else
-    mean = C_NULL
-    ivar = C_NULL
+    mean = CU_NULL
+    ivar = CU_NULL
   end
   @check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t,
               (cudnnHandle_t,cudnnBatchNormMode_t,
                Ptr{T}, Ptr{T},
-               Ptr{Nothing}, Ptr{T},
-               Ptr{Nothing}, Ptr{T},
-               Ptr{Nothing}, Ptr{T}, Ptr{T},
-               Cdouble, Ptr{T}, Ptr{T},
-               Cdouble, Ptr{T}, Ptr{T}),
+               Ptr{Nothing}, CuPtr{T},
+               Ptr{Nothing}, CuPtr{T},
+               Ptr{Nothing}, CuPtr{T}, CuPtr{T},
+               Cdouble, CuPtr{T}, CuPtr{T},
+               Cdouble, CuPtr{T}, CuPtr{T}),
               handle(), BATCHNORM_SPATIAL,
               Ref(T(alpha)), Ref(T(beta)),
               xd, x,
@@ -107,10 +107,10 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
     @check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t,
                 (Ptr{cudnnHandle_t},cudnnBatchNormMode_t,
                  Ptr{T}, Ptr{T},
-                 Ptr{Nothing}, Ptr{T},
-                 Ptr{Nothing}, Ptr{T},
-                 Ptr{Nothing}, Ptr{T}, Ptr{T},
-                 Ptr{T}, Ptr{T},
+                 Ptr{Nothing}, CuPtr{T},
+                 Ptr{Nothing}, CuPtr{T},
+                 Ptr{Nothing}, CuPtr{T}, CuPtr{T},
+                 CuPtr{T}, CuPtr{T},
                  Cdouble),
                 handle(), BATCHNORM_SPATIAL,
                 Ref(T(alpha)), Ref(T(beta)),
@@ -159,7 +159,7 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
     mean, ivar = cache.mean, cache.ivar
     info("mean and ivar are fetched from the cache")
   else
-    mean, ivar = C_NULL, C_NULL
+    mean, ivar = CU_NULL, CU_NULL
   end
 
   if eps < BATCHNORM_MIN_EPS
@@ -170,11 +170,11 @@
               (cudnnHandle_t,cudnnBatchNormMode_t,
                Ptr{T}, Ptr{T},
                Ptr{T}, Ptr{T},
-               Ptr{Nothing}, Ptr{T},
-               Ptr{Nothing}, Ptr{T},
-               Ptr{Nothing}, Ptr{T},
-               Ptr{Nothing}, Ptr{T}, Ptr{T}, Ptr{T},
-               Cdouble, Ptr{T}, Ptr{T}),
+               Ptr{Nothing}, CuPtr{T},
+               Ptr{Nothing}, CuPtr{T},
+               Ptr{Nothing}, CuPtr{T},
+               Ptr{Nothing}, CuPtr{T}, CuPtr{T}, CuPtr{T},
+               Cdouble, CuPtr{T}, CuPtr{T}),
               handle(), BATCHNORM_SPATIAL,
               Ref(T(alpha)), Ref(T(beta)),
               Ref(T(dalpha)), Ref(T(dbeta)),

View File

@@ -101,18 +101,18 @@ function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd
   if reserve == nothing
     @check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
                  (Ptr{Nothing}, Ptr{Nothing}, Cint,
-                  Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
-                  Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
-                  Ptr{Nothing}, Ptr{T},
-                  Ptr{Nothing}, Csize_t),
+                  Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
+                  Ptr{Nothing}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
+                  Ptr{Nothing}, CuPtr{T},
+                  CuPtr{Nothing}, Csize_t),
                  handle(), rnn, seqlen,
                  xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
                  workspace, length(workspace))
   else
     @check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t,
                  (Ptr{Nothing}, Ptr{Nothing}, Cint,
-                  Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
-                  Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
+                  Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
+                  CuPtr{Nothing}, Csize_t, CuPtr{Nothing}, Csize_t),
                  handle(), rnn, seqlen,
                  xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
                  workspace, length(workspace), reserve, length(reserve))
@@ -121,7 +121,7 @@ end
 xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]
 
-hDesc(h::Nothing) = C_NULL, C_NULL
+hDesc(h::Nothing) = C_NULL, CU_NULL
 hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
 function hDesc(h::CuArray)
   TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
@@ -169,10 +169,10 @@ function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho
                              wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
   @check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
                (Ptr{Nothing}, Ptr{Nothing}, Cint,
-                Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
-                Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing},
-                Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
-                Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
+                Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
+                Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing},
+                CuPtr{T}, Ptr{Ptr{Nothing}}, CuPtr{T}, Ptr{Nothing}, CuPtr{T}, Ptr{Nothing}, CuPtr{T},
+                CuPtr{Nothing}, Csize_t, CuPtr{Nothing}, Csize_t),
                handle(), rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
                wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
 end
@@ -199,12 +199,12 @@ function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, d
                                 workspace, reserve) where T
   @check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t,
                (Ptr{Nothing}, Ptr{Nothing}, Cint, # handle, rnnDesc, seqLength
-                Ptr{Ptr{Nothing}}, Ptr{T}, #x
-                Ptr{Nothing}, Ptr{T}, #hx
-                Ptr{Ptr{Nothing}}, Ptr{T}, #y
-                Ptr{Nothing}, Csize_t, #ws
-                Ptr{Nothing}, Ptr{T}, #dw
-                Ptr{Nothing}, Csize_t), #rs
+                Ptr{Ptr{Nothing}}, CuPtr{T}, #x
+                Ptr{Nothing}, CuPtr{T}, #hx
+                Ptr{Ptr{Nothing}}, CuPtr{T}, #y
+                CuPtr{Nothing}, Csize_t, #ws
+                Ptr{Nothing}, CuPtr{T}, #dw
+                CuPtr{Nothing}, Csize_t), #rs
                handle(), rnn, seqlen, xd, x, hd, h, yd, y,
                workspace, length(workspace), dwd, dw, reserve, length(reserve))
 end

View File

@@ -1,5 +1,5 @@
-init_grad(x) = zero(x)
-zero_grad!(x) = zero(x)
+init_grad(x) = zero_or_nothing(x)
+zero_grad!(x) = zero_or_nothing(x)
 zero_grad!(x::AbstractArray) = (x .= 0)
 
 scan(c::Call) = foreach(scan, c.args)
@@ -32,6 +32,13 @@ back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
 accum!(x, Δ) = x .+ Δ
 accum!(x::AbstractArray, Δ) = (x .+= Δ)
 
+# This is correct, because the `.grad` field of tracked is constrained
+# by a type parameter constructed from the (concrete) forward value and
+# thus for `nothing` to show up here, we're guaranteed that the value
+# was `nothing` during the forward pass (and thus we don't care about its
+# derivatives).
+accum!(x::Nothing, Δ) = x
+
 function back(x::Tracked, Δ, once)
   x.isleaf && (x.grad = accum!(x.grad, Δ); return)
   ref = x.ref -= 1
@@ -126,7 +133,7 @@ accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
 
 function back_(g::Grads, c::Call, Δ)
   Δs = c.func(Δ)
-  (Δs isa Tuple && length(Δs) >= length(c.args)) ||
+  ((Δs isa Tuple || Δs isa TrackedTuple) && length(Δs) >= length(c.args)) ||
     error("Gradient is not a tuple of length $(length(c.args))")
   foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
 end

View File

@@ -95,12 +95,27 @@
 Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
 
-@grad function getindex(xs::AbstractArray, i...)
-  data(xs)[i...], function (Δ)
-    Δ′ = zero(xs)
-    Δ′[i...] = data(Δ)
-    (nobacksies(:getindex, Δ′), map(_->nothing, i)...)
-  end
+struct ∇getindex{T,S}
+  xs::T
+  i::S
+end
+
+function (g::∇getindex)(Δ)
+  Δ′ = zero(g.xs)
+  Δ′[g.i...] = Δ
+  (Δ′, map(_->nothing, g.i)...)
+end
+(g::∇getindex)(Δ::TrackedArray) = track(g, Δ)
+@grad function (g::∇getindex)(Δ)
+  z, back = g(data(Δ)), function(Δ′′)
+    (Δ′′[1][g.i...],)
+  end
+  z, back
+end
+
+@grad function getindex(xs::AbstractArray, i...)
+  data(xs)[i...], ∇getindex(xs, i)
 end
 
 Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...)
@@ -361,6 +376,9 @@ x::TrackedVector * y::TrackedVector = track(*, x, y)
 @grad a::AbstractMatrix * b::AbstractVecOrMat =
   data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ)
 
+@grad a::TrackedMatrix{<:Any, <:Transpose} * b::AbstractVecOrMat =
+  data(a)*data(b), Δ -> (transpose(b * transpose(Δ)), transpose(a) * Δ)
+
 # NNlib
 
 using NNlib
@@ -422,12 +440,17 @@ end
 
 using ForwardDiff: Dual, partials, value
 
-trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
+trim(x, Δ) = ndims(Δ) == ndims(x) ? Δ : reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
 
-unbroadcast(x::AbstractArray, Δ) =
-  size(x) == size(Δ) ? Δ :
-  length(x) == length(Δ) ? trim(x, Δ) :
-    trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
+function unbroadcast(x::AbstractArray, Δ)
+  if size(x) == size(Δ)
+    return Δ
+  elseif length(x) == length(Δ)
+    return trim(x, Δ)
+  else
+    return trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
+  end
+end
 
 unbroadcast(x::Number, Δ) = sum(Δ)
 unbroadcast(x::Base.RefValue, _) = nothing

View File

@@ -122,7 +122,9 @@ accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
 init_grad(x::Tuple) = init_grad.(x)
 zero_grad!(x::Tuple) = zero_grad!.(x)
 
-track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero.(xs)))
+zero_or_nothing(x) = zero(x)
+zero_or_nothing(x::Nothing) = nothing
+track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero_or_nothing.(xs)))
 
 function Base.show(io::IO, xs::TrackedTuple)
   show(io, data(xs))
@@ -132,11 +134,18 @@ end
 Base.length(x::TrackedTuple) = length(data(x))
 
 Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
+Base.iterate(xs::TrackedTuple) = track(iterate, xs)
+Base.iterate(xs::TrackedTuple, i::Integer) = track(iterate, xs, i)
+Base.iterate(xs::TrackedTuple, i::TrackedReal) = iterate(xs, data(i))
 
 @grad function getindex(xs::TrackedTuple, i)
   data(xs)[i], Δ -> (ntuple(j -> i == j ? Δ : 0, length(xs)), nothing)
 end
 
+@grad function iterate(xs::TrackedTuple, i=1)
+  (data(xs)[i], i+1), Δ -> (ntuple(j -> i == j ? Δ[1] : 0, length(xs)), nothing)
+end
+
 # Array collection
 
 function collect(xs)

View File

@@ -5,7 +5,9 @@ children(x) = ()
 mapchildren(f, x) = x
 
 children(x::Tuple) = x
+children(x::NamedTuple) = x
 mapchildren(f, x::Tuple) = map(f, x)
+mapchildren(f, x::NamedTuple) = map(f, x)
 
 function treelike(m::Module, T, fs = fieldnames(T))
   @eval m begin

View File

@@ -5,6 +5,7 @@ using NNlib: conv, ∇conv_data, depthwiseconv
 using Printf: @sprintf
 using LinearAlgebra: diagm, dot, LowerTriangular, norm
 using Statistics: mean, std
+using ForwardDiff
 using Random
 
 # using StatsBase
@@ -36,6 +37,25 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
 
 @testset "indexing & slicing" begin
   gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
+
+  # Nested AD for getindex
+  grad_tracker = gradient([1.0, 2.0, 3.0]; nest=true) do x
+    sum(gradient(x; nest=true) do x
+      sum(gradient(x; nest=true) do x
+        sum(x[1:2])^4
+      end[1])
+    end[1])
+  end[1]
+  # We compare to ForwardDiff, since the high order derivative is not
+  # numerically stable under finite differencing.
+  grad_forward = ForwardDiff.gradient([1.0, 2.0, 3.0]) do x
+    sum(ForwardDiff.gradient(x) do x
+      sum(ForwardDiff.gradient(x) do x
+        sum(x[1:2])^4
+      end)
+    end)
+  end
+  @test grad_tracker ≈ grad_forward ≈ [288.0, 288.0, 0.0]
 end
 
 function promotiontest(f, A, B, C)
@@ -336,4 +356,14 @@ end
   @test back([1, 1]) == (32,)
 end
 
+@testset "transpose" begin
+  let f = (x,a,b)->(x = transpose(x); x * a + x * b),
+      g = x->(a = transpose(x); b = transpose(a); b * [1.0 1.0; 2.0 3.0] + a * [1.0 1.0; 2.0 3.0])
+    @test gradient(x->sum(f(x, [1.0; 1.0], [1.0; 1.0])), [1.0 1.0; 1.0 1.0])[1] ==
+      [2.0 2.0; 2.0 2.0]
+    @test gradient(x->sum(g(x)), [1.0 1.0; 1.0 1.0])[1] ==
+      [4.0 7.0; 7.0 10.0]
+  end
+end
+
 end #testset