Ac_mul_B derivatives
This commit is contained in:
parent
d4229c4815
commit
fcd091e8f0
@ -1,5 +1,3 @@
|
|||||||
import Base: *
|
|
||||||
|
|
||||||
toarray(xs::AbstractArray, ys::AbstractArray) = ys
|
toarray(xs::AbstractArray, ys::AbstractArray) = ys
|
||||||
toarray(xs::AbstractArray, y) = similar(xs, typeof(y), ()) .= y
|
toarray(xs::AbstractArray, y) = similar(xs, typeof(y), ()) .= y
|
||||||
|
|
||||||
@ -66,19 +64,33 @@ back(::typeof(mean), Δ, xs::TrackedArray, region) =
|
|||||||
|
|
||||||
# BLAS
|
# BLAS
|
||||||
|
|
||||||
a::TrackedMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))
|
for f in :[*, Ac_mul_B].args
|
||||||
a::TrackedMatrix * b::AbstractMatrix = TrackedArray(Call(*, a, b))
|
@eval begin
|
||||||
a::AbstractMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))
|
import Base.$f
|
||||||
|
$f(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b))
|
||||||
|
$f(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call($f, a, b))
|
||||||
|
$f(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b))
|
||||||
|
|
||||||
a::TrackedMatrix * b::TrackedVector = TrackedArray(Call(*, a, b))
|
$f(a::TrackedMatrix, b::TrackedVector) = TrackedArray(Call($f, a, b))
|
||||||
a::TrackedMatrix * b::AbstractVector = TrackedArray(Call(*, a, b))
|
$f(a::TrackedMatrix, b::AbstractVector) = TrackedArray(Call($f, a, b))
|
||||||
a::AbstractMatrix * b::TrackedVector = TrackedArray(Call(*, a, b))
|
$f(a::AbstractMatrix, b::TrackedVector) = TrackedArray(Call($f, a, b))
|
||||||
|
|
||||||
|
$f(a::TrackedVector, b::TrackedVector) = TrackedArray(Call($f, a, b))
|
||||||
|
$f(a::TrackedVector, b::AbstractVector) = TrackedArray(Call($f, a, b))
|
||||||
|
$f(a::AbstractVector, b::TrackedVector) = TrackedArray(Call($f, a, b))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat)
|
function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat)
|
||||||
@back(a, A_mul_Bt(Δ, data(b)))
|
@back(a, A_mul_Bt(Δ, data(b)))
|
||||||
@back(b, At_mul_B(data(a), Δ))
|
@back(b, At_mul_B(data(a), Δ))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function back(::typeof(Ac_mul_B), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
|
||||||
|
@back(a, A_mul_Bt(Δ, data(b))')
|
||||||
|
@back(b, *(data(a), Δ))
|
||||||
|
end
|
||||||
|
|
||||||
# Fast path for matrix-vector
|
# Fast path for matrix-vector
|
||||||
function back(::typeof(*), Δ::AbstractVector, W::TrackedMatrix, x::AbstractVector)
|
function back(::typeof(*), Δ::AbstractVector, W::TrackedMatrix, x::AbstractVector)
|
||||||
if isleaf(W)
|
if isleaf(W)
|
||||||
|
@ -9,6 +9,8 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||||||
@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
|
@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
|
||||||
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
|
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
|
||||||
|
|
||||||
|
@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
|
||||||
|
|
||||||
@test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))
|
@test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))
|
||||||
|
|
||||||
@test gradtest(x -> softmax(x).*(1:3), 3)
|
@test gradtest(x -> softmax(x).*(1:3), 3)
|
||||||
|
Loading…
Reference in New Issue
Block a user