remove old-style definition and test
This commit is contained in:
parent
9f9803eec6
commit
77178b7d67
@ -312,9 +312,6 @@ end
|
|||||||
LinearAlgebra.diagm(x::Pair{<:Integer, <:TrackedVector}) = track(diagm, x)
|
LinearAlgebra.diagm(x::Pair{<:Integer, <:TrackedVector}) = track(diagm, x)
|
||||||
@grad diagm(x::Pair) = diagm(x[1] => data(x[2])), Δ -> (diag(Δ, x[1]),)
|
@grad diagm(x::Pair) = diagm(x[1] => data(x[2])), Δ -> (diag(Δ, x[1]),)
|
||||||
|
|
||||||
LinearAlgebra.diagm(x::TrackedVector) = track(diagm, x)
|
|
||||||
@grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),)
|
|
||||||
|
|
||||||
x::TrackedMatrix * y::AbstractMatrix = track(*, x, y)
|
x::TrackedMatrix * y::AbstractMatrix = track(*, x, y)
|
||||||
x::AbstractMatrix * y::TrackedMatrix = track(*, x, y)
|
x::AbstractMatrix * y::TrackedMatrix = track(*, x, y)
|
||||||
x::TrackedMatrix * y::TrackedMatrix = track(*, x, y)
|
x::TrackedMatrix * y::TrackedMatrix = track(*, x, y)
|
||||||
|
@ -3,7 +3,7 @@ using Flux.Tracker, Test, NNlib
|
|||||||
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
|
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
|
||||||
using NNlib: conv, depthwiseconv
|
using NNlib: conv, depthwiseconv
|
||||||
using Printf: @sprintf
|
using Printf: @sprintf
|
||||||
using LinearAlgebra: Diagonal, dot, LowerTriangular, norm
|
using LinearAlgebra: diagm, dot, LowerTriangular, norm
|
||||||
using Statistics: mean, std
|
using Statistics: mean, std
|
||||||
using Random
|
using Random
|
||||||
# using StatsBase
|
# using StatsBase
|
||||||
@ -127,7 +127,7 @@ end
|
|||||||
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
|
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
|
||||||
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))
|
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))
|
||||||
|
|
||||||
@test gradtest(f-> Matrix(Diagonal(f)), rand(3))
|
@test gradtest(x -> diagm(0 => x), rand(3))
|
||||||
|
|
||||||
@test gradtest(W -> inv(log.(W * W)), (5,5))
|
@test gradtest(W -> inv(log.(W * W)), (5,5))
|
||||||
@test gradtest((A, B) -> A / B , (1,5), (5,5))
|
@test gradtest((A, B) -> A / B , (1,5), (5,5))
|
||||||
|
Loading…
Reference in New Issue
Block a user