avoid back! fallback
This commit is contained in:
parent
c9eb58f146
commit
f12b1d0ca1
@ -3,6 +3,7 @@ module Tracker
|
||||
export track, back!
|
||||
|
||||
data(x) = x
|
||||
istracked(x) = false
|
||||
|
||||
struct Call{F,As<:Tuple}
|
||||
func::F
|
||||
@ -14,8 +15,7 @@ Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
|
||||
(c::Call)() = c.func(data.(c.args)...)
|
||||
|
||||
back!(c::Call, Δ) = back!(c.func, Δ, c.args...)
|
||||
|
||||
back!(f, Δ) = nothing
|
||||
back!(::Call{Void}, Δ) = nothing
|
||||
|
||||
struct TrackedArray{T,N,A} <: AbstractArray{T,N}
|
||||
f::Call
|
||||
@ -37,6 +37,7 @@ TrackedArray(c::Call) = TrackedArray(c, c())
|
||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x)
|
||||
|
||||
track(xs) = TrackedArray(xs)
|
||||
istracked(x::TrackedArray) = true
|
||||
data(x::TrackedArray) = x.x
|
||||
grad(x::TrackedArray) = x.Δ
|
||||
|
||||
@ -45,6 +46,13 @@ function back!(x::TrackedArray, Δ)
|
||||
back!(x.f, Δ)
|
||||
end
|
||||
|
||||
macro back!(x, Δ)
|
||||
quote
|
||||
x = $(esc(x))
|
||||
istracked(x) && back!(x, $(esc(Δ)))
|
||||
end
|
||||
end
|
||||
|
||||
# Fallthrough methods
|
||||
|
||||
for f in :[Base.size, Base.ndims].args
|
||||
|
@ -5,7 +5,7 @@ Base.getindex(xs::TrackedArray, i...) = TrackedArray(Call(getindex, xs, i...))
|
||||
function back!(::typeof(getindex), Δ, xs::TrackedArray, i...)
|
||||
Δ′ = zeros(xs)
|
||||
Δ′[i...] = Δ
|
||||
back!(xs, Δ′)
|
||||
@back!(xs, Δ′)
|
||||
end
|
||||
|
||||
Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs))
|
||||
@ -15,8 +15,8 @@ a::TrackedMatrix * b::AbstractMatrix = TrackedArray(Call(*, a, b))
|
||||
a::AbstractMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))
|
||||
|
||||
function back!(::typeof(*), Δ, a::AbstractMatrix, b::AbstractMatrix)
|
||||
back!(a, A_mul_Bt(Δ, data(b)))
|
||||
back!(b, At_mul_B(data(a), Δ))
|
||||
@back!(a, A_mul_Bt(Δ, data(b)))
|
||||
@back!(b, At_mul_B(data(a), Δ))
|
||||
end
|
||||
|
||||
# Broadcasting
|
||||
@ -39,7 +39,7 @@ end
|
||||
|
||||
function back!(b::Broadcasted, Δ, args...)
|
||||
Δargs = ntuple(i -> Δ .* getindex.(partials.(b.data), i), length(args))
|
||||
back!.(args, Δargs)
|
||||
map((x, Δ) -> @back!(x, Δ), args, Δargs)
|
||||
return
|
||||
end
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user