From fcd091e8f06fc7a8824c4ca12d38dd23a4da4f08 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 8 Nov 2017 22:00:19 +0000 Subject: [PATCH] Ac_mul_B derivatives --- src/tracker/lib.jl | 28 ++++++++++++++++++++-------- test/tracker.jl | 2 ++ 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index 2ee5d659..aab26dfe 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -1,5 +1,3 @@ -import Base: * - toarray(xs::AbstractArray, ys::AbstractArray) = ys toarray(xs::AbstractArray, y) = similar(xs, typeof(y), ()) .= y @@ -66,19 +64,33 @@ back(::typeof(mean), Δ, xs::TrackedArray, region) = # BLAS -a::TrackedMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b)) -a::TrackedMatrix * b::AbstractMatrix = TrackedArray(Call(*, a, b)) -a::AbstractMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b)) +for f in :[*, Ac_mul_B].args + @eval begin + import Base.$f + $f(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b)) + $f(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call($f, a, b)) + $f(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b)) -a::TrackedMatrix * b::TrackedVector = TrackedArray(Call(*, a, b)) -a::TrackedMatrix * b::AbstractVector = TrackedArray(Call(*, a, b)) -a::AbstractMatrix * b::TrackedVector = TrackedArray(Call(*, a, b)) + $f(a::TrackedMatrix, b::TrackedVector) = TrackedArray(Call($f, a, b)) + $f(a::TrackedMatrix, b::AbstractVector) = TrackedArray(Call($f, a, b)) + $f(a::AbstractMatrix, b::TrackedVector) = TrackedArray(Call($f, a, b)) + + $f(a::TrackedVector, b::TrackedVector) = TrackedArray(Call($f, a, b)) + $f(a::TrackedVector, b::AbstractVector) = TrackedArray(Call($f, a, b)) + $f(a::AbstractVector, b::TrackedVector) = TrackedArray(Call($f, a, b)) + end +end function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat) @back(a, A_mul_Bt(Δ, data(b))) @back(b, At_mul_B(data(a), Δ)) end +function back(::typeof(Ac_mul_B), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real}) + @back(a, A_mul_Bt(Δ, data(b))') + @back(b, *(data(a), Δ)) +end + # Fast path for matrix-vector function back(::typeof(*), Δ::AbstractVector, W::TrackedMatrix, x::AbstractVector) if isleaf(W) diff --git a/test/tracker.jl b/test/tracker.jl index 52a73a07..69f37367 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -9,6 +9,8 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2) @test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2) +@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10)) + @test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5)) @test gradtest(x -> softmax(x).*(1:3), 3)