avoid back! fallback
This commit is contained in:
parent
c9eb58f146
commit
f12b1d0ca1
@ -3,6 +3,7 @@ module Tracker
|
|||||||
export track, back!
|
export track, back!
|
||||||
|
|
||||||
data(x) = x
|
data(x) = x
|
||||||
|
istracked(x) = false
|
||||||
|
|
||||||
struct Call{F,As<:Tuple}
|
struct Call{F,As<:Tuple}
|
||||||
func::F
|
func::F
|
||||||
@ -14,8 +15,7 @@ Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
|
|||||||
(c::Call)() = c.func(data.(c.args)...)
|
(c::Call)() = c.func(data.(c.args)...)
|
||||||
|
|
||||||
back!(c::Call, Δ) = back!(c.func, Δ, c.args...)
|
back!(c::Call, Δ) = back!(c.func, Δ, c.args...)
|
||||||
|
back!(::Call{Void}, Δ) = nothing
|
||||||
back!(f, Δ) = nothing
|
|
||||||
|
|
||||||
struct TrackedArray{T,N,A} <: AbstractArray{T,N}
|
struct TrackedArray{T,N,A} <: AbstractArray{T,N}
|
||||||
f::Call
|
f::Call
|
||||||
@ -37,6 +37,7 @@ TrackedArray(c::Call) = TrackedArray(c, c())
|
|||||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x)
|
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x)
|
||||||
|
|
||||||
track(xs) = TrackedArray(xs)
|
track(xs) = TrackedArray(xs)
|
||||||
|
istracked(x::TrackedArray) = true
|
||||||
data(x::TrackedArray) = x.x
|
data(x::TrackedArray) = x.x
|
||||||
grad(x::TrackedArray) = x.Δ
|
grad(x::TrackedArray) = x.Δ
|
||||||
|
|
||||||
@ -45,6 +46,13 @@ function back!(x::TrackedArray, Δ)
|
|||||||
back!(x.f, Δ)
|
back!(x.f, Δ)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
macro back!(x, Δ)
|
||||||
|
quote
|
||||||
|
x = $(esc(x))
|
||||||
|
istracked(x) && back!(x, $(esc(Δ)))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
# Fallthrough methods
|
# Fallthrough methods
|
||||||
|
|
||||||
for f in :[Base.size, Base.ndims].args
|
for f in :[Base.size, Base.ndims].args
|
||||||
|
@ -5,7 +5,7 @@ Base.getindex(xs::TrackedArray, i...) = TrackedArray(Call(getindex, xs, i...))
|
|||||||
function back!(::typeof(getindex), Δ, xs::TrackedArray, i...)
|
function back!(::typeof(getindex), Δ, xs::TrackedArray, i...)
|
||||||
Δ′ = zeros(xs)
|
Δ′ = zeros(xs)
|
||||||
Δ′[i...] = Δ
|
Δ′[i...] = Δ
|
||||||
back!(xs, Δ′)
|
@back!(xs, Δ′)
|
||||||
end
|
end
|
||||||
|
|
||||||
Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs))
|
Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs))
|
||||||
@ -15,8 +15,8 @@ a::TrackedMatrix * b::AbstractMatrix = TrackedArray(Call(*, a, b))
|
|||||||
a::AbstractMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))
|
a::AbstractMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))
|
||||||
|
|
||||||
function back!(::typeof(*), Δ, a::AbstractMatrix, b::AbstractMatrix)
|
function back!(::typeof(*), Δ, a::AbstractMatrix, b::AbstractMatrix)
|
||||||
back!(a, A_mul_Bt(Δ, data(b)))
|
@back!(a, A_mul_Bt(Δ, data(b)))
|
||||||
back!(b, At_mul_B(data(a), Δ))
|
@back!(b, At_mul_B(data(a), Δ))
|
||||||
end
|
end
|
||||||
|
|
||||||
# Broadcasting
|
# Broadcasting
|
||||||
@ -39,7 +39,7 @@ end
|
|||||||
|
|
||||||
function back!(b::Broadcasted, Δ, args...)
|
function back!(b::Broadcasted, Δ, args...)
|
||||||
Δargs = ntuple(i -> Δ .* getindex.(partials.(b.data), i), length(args))
|
Δargs = ntuple(i -> Δ .* getindex.(partials.(b.data), i), length(args))
|
||||||
back!.(args, Δargs)
|
map((x, Δ) -> @back!(x, Δ), args, Δargs)
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user