Ac_mul_B derivatives

This commit is contained in:
Mike J Innes 2017-11-08 22:00:19 +00:00
parent d4229c4815
commit fcd091e8f0
2 changed files with 22 additions and 8 deletions

View File

@ -1,5 +1,3 @@
import Base: *
toarray(xs::AbstractArray, ys::AbstractArray) = ys toarray(xs::AbstractArray, ys::AbstractArray) = ys
toarray(xs::AbstractArray, y) = similar(xs, typeof(y), ()) .= y toarray(xs::AbstractArray, y) = similar(xs, typeof(y), ()) .= y
@ -66,19 +64,33 @@ back(::typeof(mean), Δ, xs::TrackedArray, region) =
# BLAS # BLAS
a::TrackedMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b)) for f in :[*, Ac_mul_B].args
a::TrackedMatrix * b::AbstractMatrix = TrackedArray(Call(*, a, b)) @eval begin
a::AbstractMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b)) import Base.$f
$f(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b))
$f(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call($f, a, b))
$f(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b))
a::TrackedMatrix * b::TrackedVector = TrackedArray(Call(*, a, b)) $f(a::TrackedMatrix, b::TrackedVector) = TrackedArray(Call($f, a, b))
a::TrackedMatrix * b::AbstractVector = TrackedArray(Call(*, a, b)) $f(a::TrackedMatrix, b::AbstractVector) = TrackedArray(Call($f, a, b))
a::AbstractMatrix * b::TrackedVector = TrackedArray(Call(*, a, b)) $f(a::AbstractMatrix, b::TrackedVector) = TrackedArray(Call($f, a, b))
$f(a::TrackedVector, b::TrackedVector) = TrackedArray(Call($f, a, b))
$f(a::TrackedVector, b::AbstractVector) = TrackedArray(Call($f, a, b))
$f(a::AbstractVector, b::TrackedVector) = TrackedArray(Call($f, a, b))
end
end
function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat) function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat)
@back(a, A_mul_Bt(Δ, data(b))) @back(a, A_mul_Bt(Δ, data(b)))
@back(b, At_mul_B(data(a), Δ)) @back(b, At_mul_B(data(a), Δ))
end end
# Backward pass for `Ac_mul_B(a, b)` (the product `a' * b`), given the gradient `Δ`
# of the output. The method is restricted to real element types, where the adjoint
# is the plain transpose, so the real-valued product rule is used:
#   - for `a`: `b * Δ'`, computed here as the transpose of `Δ * b'`;
#   - for `b`: `a * Δ`.
# NOTE(review): `@back` and `data` are defined outside this view. Presumably `data`
# yields the underlying untracked array and `@back` forwards the given gradient to
# its first argument when that argument is tracked — confirm against their definitions.
function back(::typeof(Ac_mul_B), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
@back(a, A_mul_Bt(Δ, data(b))')
@back(b, *(data(a), Δ))
end
# Fast path for matrix-vector # Fast path for matrix-vector
function back(::typeof(*), Δ::AbstractVector, W::TrackedMatrix, x::AbstractVector) function back(::typeof(*), Δ::AbstractVector, W::TrackedMatrix, x::AbstractVector)
if isleaf(W) if isleaf(W)

View File

@ -9,6 +9,8 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2) @test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2) @test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
# Runs `gradtest` on the product `w' * x` with a random 10×2 matrix `w` and a random
# length-10 vector `x` (presumably exercising the `Ac_mul_B` gradient — confirm).
@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
@test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5)) @test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))
@test gradtest(x -> softmax(x).*(1:3), 3) @test gradtest(x -> softmax(x).*(1:3), 3)