diagm
This commit is contained in:
parent
2fec75005d
commit
a4bf5936b0
|
@ -91,6 +91,9 @@ back(::typeof(mean), Δ, xs::TrackedArray, region) =
|
|||
|
||||
# BLAS
|
||||
|
||||
Base.diagm(x::TrackedVector) = TrackedArray(Call(diagm, x))
|
||||
back(::typeof(diagm), Δ, x) = @back(x, diag(Δ))
|
||||
|
||||
for f in :[*, Ac_mul_B, A_mul_Bc].args
|
||||
@eval begin
|
||||
import Base.$f
|
||||
|
|
|
@ -31,6 +31,8 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||
@test gradtest(vcat, rand(5), rand(3), rand(8))
|
||||
@test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2))
|
||||
|
||||
@test gradtest(diagm, rand(3))
|
||||
|
||||
@testset "mean" begin
|
||||
@test gradtest(mean, rand(2, 3))
|
||||
|
||||
|
|
Loading…
Reference in New Issue