avoid back! fallback

This commit is contained in:
Mike J Innes 2017-08-19 17:40:07 +01:00
parent c9eb58f146
commit f12b1d0ca1
2 changed files with 14 additions and 6 deletions

View File

@ -3,6 +3,7 @@ module Tracker
export track, back!
data(x) = x
istracked(x) = false
struct Call{F,As<:Tuple}
func::F
@ -14,8 +15,7 @@ Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
(c::Call)() = c.func(data.(c.args)...)
back!(c::Call, Δ) = back!(c.func, Δ, c.args...)
back!(f, Δ) = nothing
back!(::Call{Void}, Δ) = nothing
struct TrackedArray{T,N,A} <: AbstractArray{T,N}
f::Call
@ -37,6 +37,7 @@ TrackedArray(c::Call) = TrackedArray(c, c())
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x)
track(xs) = TrackedArray(xs)
istracked(x::TrackedArray) = true
data(x::TrackedArray) = x.x
grad(x::TrackedArray) = x.Δ
@ -45,6 +46,13 @@ function back!(x::TrackedArray, Δ)
back!(x.f, Δ)
end
macro back!(x, Δ)
quote
x = $(esc(x))
istracked(x) && back!(x, $(esc(Δ)))
end
end
# Fallthrough methods
for f in :[Base.size, Base.ndims].args

View File

@ -5,7 +5,7 @@ Base.getindex(xs::TrackedArray, i...) = TrackedArray(Call(getindex, xs, i...))
function back!(::typeof(getindex), Δ, xs::TrackedArray, i...)
Δ′ = zeros(xs)
Δ′[i...] = Δ
back!(xs, Δ′)
@back!(xs, Δ′)
end
Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs))
@ -15,8 +15,8 @@ a::TrackedMatrix * b::AbstractMatrix = TrackedArray(Call(*, a, b))
a::AbstractMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))
function back!(::typeof(*), Δ, a::AbstractMatrix, b::AbstractMatrix)
back!(a, A_mul_Bt(Δ, data(b)))
back!(b, At_mul_B(data(a), Δ))
@back!(a, A_mul_Bt(Δ, data(b)))
@back!(b, At_mul_B(data(a), Δ))
end
# Broadcasting
@ -39,7 +39,7 @@ end
function back!(b::Broadcasted, Δ, args...)
Δargs = ntuple(i -> Δ .* getindex.(partials.(b.data), i), length(args))
back!.(args, Δargs)
map((x, Δ) -> @back!(x, Δ), args, Δargs)
return
end