diff --git a/src/tracker/array.jl b/src/tracker/array.jl index c75b5c1c..f13feb77 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -82,6 +82,17 @@ Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...) end end +Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...) + +@grad function view(x::AbstractArray, inds...) + view(data(x), inds...), function (Δ) + grad_output = fill!(similar(data(x)), 0) + subgrad = view(grad_output, inds...) + setindex!(subgrad, Δ, :) + (grad_output, map(_->nothing, inds)...) + end +end + Base.:-(xs::TrackedArray) = track(-, xs) @grad -(xs) = -data(xs), Δ -> (-Δ,)