2017-08-19 09:14:50 +00:00
|
|
|
import Base: *
|
|
|
|
|
2017-08-19 15:02:19 +00:00
|
|
|
# Indexing a tracked array. The result is a new `TrackedArray` built from
# `Call(getindex, xs, i...)`, i.e. the operation and its arguments are handed to `Call`
# together with the indices exactly as the caller passed them.
# The matching backward rule is `back!(::typeof(getindex), Δ, xs, i...)`.
function Base.getindex(xs::TrackedArray, i...)
    return TrackedArray(Call(getindex, xs, i...))
end
|
2017-08-19 09:14:50 +00:00
|
|
|
|
2017-08-19 15:02:19 +00:00
|
|
|
# Backward rule for `xs[i...]`.
#
# Arguments:
#   Δ   gradient with respect to the indexed result `xs[i...]`
#   xs  the tracked array that was indexed
#   i   the indices used in the forward `getindex` call
#
# A zero array shaped like `xs` is filled with `Δ` at the indexed positions and passed on
# to `xs` through `@back!`.
function back!(::typeof(getindex), Δ, xs::TrackedArray, i...)
    Δ′ = zeros(xs)
    # Accumulate instead of assigning: with repeated indices (e.g. `xs[[1, 1]]`) the same
    # position is read several times in the forward pass, so its gradient is the SUM of
    # the matching entries of `Δ`. A plain `Δ′[i...] = Δ` would keep only the last one.
    # Writing through a single view updates the parent element by element, so repeated
    # positions add up.
    dest = view(Δ′, i...)
    dest .= dest .+ Δ
    @back!(xs, Δ′)
end
|
|
|
|
|
|
|
|
# Unary minus on a tracked array: returns a new `TrackedArray` built from `Call(-, xs)`.
# NOTE(review): no `back!(::typeof(-), ...)` method is visible in this part of the file;
# confirm that a backward rule for negation exists elsewhere.
function Base.:-(xs::TrackedArray)
    return TrackedArray(Call(-, xs))
end
|
|
|
|
|
|
|
|
# Matrix products where at least one operand is tracked. Each method returns a new
# `TrackedArray` built from `Call(*, a, b)`; the backward rule is
# `back!(::typeof(*), Δ, a, b)`.
# The tracked-by-tracked method is listed first; presumably it also resolves the ambiguity
# between the two mixed methods when both operands are tracked (assumes
# `TrackedMatrix <: AbstractMatrix` — confirm against the type definition).
Base.:*(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call(*, a, b))
Base.:*(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call(*, a, b))
Base.:*(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call(*, a, b))
|
|
|
|
|
|
|
|
# Backward rule for the matrix product `a * b`.
#
# Arguments:
#   Δ     gradient with respect to the product
#   a, b  the operands of the forward call (either may be tracked or a plain matrix)
#
# `data(...)` is applied to the *other* operand before it is used in each product.
# The gradient expressions are written inside the `@back!` calls on purpose: they are
# passed to the macro unevaluated, so whether they run is up to `@back!`.
function back!(::typeof(*), Δ, a::AbstractMatrix, b::AbstractMatrix)
    # Gradient for the left operand: Δ times the transpose of b.
    @back!(a, A_mul_Bt(Δ, data(b)))
    # Gradient for the right operand: transpose of a times Δ.
    @back!(b, At_mul_B(data(a), Δ))
end
|
|
|
|
|
|
|
|
# Broadcasting
|
|
|
|
|
|
|
|
using ForwardDiff: Dual, partials
|
|
|
|
|
|
|
|
# Callable wrapper around the result of a broadcast. `tracked_broadcast` stores the output
# of `broadcast(f, dargs...)` in `data`, and `back!(::Broadcasted, ...)` later reads the
# `partials` of each element from it.
struct Broadcasted{T}
    data::T
end

# Calling a `Broadcasted` ignores every argument in `xs` and returns the `value` field of
# each element of `b.data`, collected with `map`.
function (b::Broadcasted)(xs...)
    return map(b.data) do d
        d.value
    end
end
|
|
|
|
|
|
|
|
# Fallback: anything that is not a `TrackedArray` is returned unchanged; the second
# argument is ignored.
function dualify(xs, n)
    return xs
end
|
|
|
|
# Tracked arrays are lifted element-wise to `Dual` numbers: every element of `data(xs)` is
# paired with the same `ps`. Wrapping `ps` in `Ref` makes the broadcast treat it as a
# single value instead of iterating over it.
function dualify(xs::TrackedArray, ps)
    return broadcast(Dual, data(xs), Ref(ps))
end
|
|
|
|
|
|
|
|
# Broadcast `f` over `args` while recording the operation.
#
# Each argument is first passed through `dualify` with its own seed, `f` is broadcast over
# the results, and the output is wrapped in a `Broadcasted`, which becomes the function of
# the recorded `Call` (with the ORIGINAL `args` as its arguments).
function tracked_broadcast(f, args::Vararg{Any,N}) where N
    # Seed for argument `i`: an N-tuple of Bools that is `true` only in slot `i`.
    seed(i) = ntuple(j -> i==j, Val{N})
    lifted = ntuple(i -> dualify(args[i], seed(i)), Val{N})
    out = broadcast(f, lifted...)
    return TrackedArray(Call(Broadcasted(out), args...))
end
|
|
|
|
|
|
|
|
# Backward rule for a recorded broadcast.
#
# Arguments:
#   b     the `Broadcasted` holding the broadcast output (read via `partials`)
#   Δ     gradient with respect to the broadcast result
#   args  the original arguments of the broadcast
#
# For argument `i`, the gradient is `Δ` multiplied element-wise by the `i`-th partial of
# every element of `b.data`. Always returns `nothing`.
#
# NOTE(review): each gradient has the shape produced by broadcasting `Δ` with `b.data`;
# nothing here reduces it back to the shape of an argument that was expanded by the
# broadcast (e.g. a vector broadcast against a matrix) — confirm that this is handled
# downstream.
function back!(b::Broadcasted, Δ, args...)
    grad(i) = Δ .* getindex.(partials.(b.data), i)
    Δargs = ntuple(grad, length(args))
    for (x, δ) in zip(args, Δargs)
        @back!(x, δ)
    end
    return
end
|
2017-08-19 15:02:19 +00:00
|
|
|
|
|
|
|
# Hooks into `Base.Broadcast` so that a broadcast involving a `TrackedArray` is classified
# under the `TrackedArray` container type.
# NOTE(review): `_containertype` is an underscore-prefixed Base function, so it is
# presumably internal and tied to a specific Julia version — confirm before upgrading.

# Any `TrackedArray` subtype maps to the `TrackedArray` container type.
function Base.Broadcast._containertype(::Type{<:TrackedArray})
    return TrackedArray
end

# Mixing `Array` and `TrackedArray` containers (in either order) yields `TrackedArray`.
function Base.Broadcast.promote_containertype(::Type{Array}, ::Type{TrackedArray})
    return TrackedArray
end

function Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{Array})
    return TrackedArray
end
|
|
|
|
# Index ranges used when broadcasting under the `TrackedArray` container type.

# A `Ref` contributes no indices (empty tuple).
function Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A::Ref)
    return ()
end

# Everything else contributes its own `indices`.
function Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A)
    return indices(A)
end
|
|
|
|
|
|
|
|
# Entry point for broadcasts classified under the `TrackedArray` container type: forward
# the function and all arguments to `tracked_broadcast`.
function Base.Broadcast.broadcast_c(f, ::Type{TrackedArray}, A, Bs...)
    return tracked_broadcast(f, A, Bs...)
end
|