From 29787eba452a0e12e7c152fe7ded67393f18a8b7 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 12 Dec 2017 17:23:15 +0000 Subject: [PATCH] fixes #114 --- src/tracker/lib.jl | 9 +++++++++ test/tracker.jl | 2 ++ 2 files changed, 11 insertions(+) diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index ab250e39..f3221bd8 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -58,6 +58,15 @@ Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...) Base.mean(xs::TrackedArray) = TrackedArray(Call(mean, xs), toarray(xs.data, mean(xs.data))) Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region)) +LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys)))) +LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys)))) +LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys)))) + +function back(::typeof(dot), Δ, xs, ys) + @back(xs, Δ.*ys) + @back(ys, Δ.*xs) +end + # Hacks to get std working Base.std(x::TrackedArray; mean = Base.mean(x)) = sqrt.(sum((x .- mean).^2) ./ (length(x)-1)) diff --git a/test/tracker.jl b/test/tracker.jl index 7d9ef4f5..ac031915 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -38,6 +38,8 @@ end @test gradtest(x -> std(x), rand(5,5)) @test gradtest(x -> std(x, 1), rand(5,5)) +@test gradtest((x, y) -> x .* y, rand(5), rand(5)) + @test gradtest(rand(5)) do x y = x.^2 2y + x