From f12b1d0ca1cf1d6582584a650676601eb5c03a43 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Sat, 19 Aug 2017 17:40:07 +0100 Subject: [PATCH] avoid back! fallback --- src/Tracker/Tracker.jl | 12 ++++++++++-- src/Tracker/lib.jl | 8 ++++---- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/Tracker/Tracker.jl b/src/Tracker/Tracker.jl index 2447f5b5..c6276887 100644 --- a/src/Tracker/Tracker.jl +++ b/src/Tracker/Tracker.jl @@ -3,6 +3,7 @@ module Tracker export track, back! data(x) = x +istracked(x) = false struct Call{F,As<:Tuple} func::F @@ -14,8 +15,7 @@ Call(f, args...) = Call{typeof(f),typeof(args)}(f, args) (c::Call)() = c.func(data.(c.args)...) back!(c::Call, Δ) = back!(c.func, Δ, c.args...) - -back!(f, Δ) = nothing +back!(::Call{Void}, Δ) = nothing struct TrackedArray{T,N,A} <: AbstractArray{T,N} f::Call @@ -37,6 +37,7 @@ TrackedArray(c::Call) = TrackedArray(c, c()) TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x) track(xs) = TrackedArray(xs) +istracked(x::TrackedArray) = true data(x::TrackedArray) = x.x grad(x::TrackedArray) = x.Δ @@ -45,6 +46,13 @@ function back!(x::TrackedArray, Δ) back!(x.f, Δ) end +macro back!(x, Δ) + quote + x = $(esc(x)) + istracked(x) && back!(x, $(esc(Δ))) + end +end + # Fallthrough methods for f in :[Base.size, Base.ndims].args diff --git a/src/Tracker/lib.jl b/src/Tracker/lib.jl index 31e26f66..55a1cdaf 100644 --- a/src/Tracker/lib.jl +++ b/src/Tracker/lib.jl @@ -5,7 +5,7 @@ Base.getindex(xs::TrackedArray, i...) = TrackedArray(Call(getindex, xs, i...)) function back!(::typeof(getindex), Δ, xs::TrackedArray, i...) Δ′ = zeros(xs) Δ′[i...] = Δ - back!(xs, Δ′) + @back!(xs, Δ′) end Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs)) @@ -15,8 +15,8 @@ a::TrackedMatrix * b::AbstractMatrix = TrackedArray(Call(*, a, b)) a::AbstractMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b)) function back!(::typeof(*), Δ, a::AbstractMatrix, b::AbstractMatrix) - back!(a, A_mul_Bt(Δ, data(b))) - back!(b, At_mul_B(data(a), Δ)) + @back!(a, A_mul_Bt(Δ, data(b))) + @back!(b, At_mul_B(data(a), Δ)) end # Broadcasting @@ -39,7 +39,7 @@ end function back!(b::Broadcasted, Δ, args...) Δargs = ntuple(i -> Δ .* getindex.(partials.(b.data), i), length(args)) - back!.(args, Δargs) + map((x, Δ) -> @back!(x, Δ), args, Δargs) return end