diff --git a/src/tracker/array.jl b/src/tracker/array.jl index c7d1178b..432244ce 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -82,6 +82,17 @@ Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...) end end +Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...) + +@grad function view(x::AbstractArray, inds...) + view(data(x), inds...), function (Δ) + grad_output = zero(x) + subgrad = view(grad_output, inds...) + subgrad[:] = data(Δ) + (nobacksies(:view, grad_output), map(_->nothing, inds)...) + end +end + Base.:-(xs::TrackedArray) = track(-, xs) @grad -(xs) = -data(xs), Δ -> (-Δ,) diff --git a/test/tracker.jl b/test/tracker.jl index 1f5f6240..baa65cce 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -33,6 +33,11 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...) @test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5)) @test gradtest(x -> x', rand(5)) + +@testset "indexing & slicing" begin + gradtest(x->view(x, 1:2, 1:2), rand(4, 4)) +end + function promotiontest(f, A, B, C) r0 = f(A, B, C) r1 = f(param(A), B, C)