From 5f99e5775aec0f2b14b8c277816ca20622993314 Mon Sep 17 00:00:00 2001 From: Roger-luo Date: Wed, 24 Oct 2018 15:40:10 -0400 Subject: [PATCH 1/3] fix #458 --- src/tracker/array.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index c75b5c1c..f13feb77 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -82,6 +82,17 @@ Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...) end end +Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...) + +@grad function view(x::AbstractArray, inds...) + view(data(x), inds...), function (Δ) + grad_output = fill!(similar(data(x)), 0) + subgrad = view(grad_output, inds...) + setindex!(subgrad, Δ, :) + (grad_output, map(_->nothing, inds)...) + end +end + Base.:-(xs::TrackedArray) = track(-, xs) @grad -(xs) = -data(xs), Δ -> (-Δ,) From a3cda9016c48ee19367e7732bdb560b90ef5fe5b Mon Sep 17 00:00:00 2001 From: Roger-luo Date: Thu, 25 Oct 2018 13:48:33 -0400 Subject: [PATCH 2/3] apply Mike's change --- src/tracker/array.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index f13feb77..9c89b5f6 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -86,9 +86,9 @@ Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...) @grad function view(x::AbstractArray, inds...) view(data(x), inds...), function (Δ) - grad_output = fill!(similar(data(x)), 0) + grad_output = zero(x) subgrad = view(grad_output, inds...) - setindex!(subgrad, Δ, :) + subgrad[:] = Δ (grad_output, map(_->nothing, inds)...) end end From e5d58699e6255ed662104af13f3215dfa922795b Mon Sep 17 00:00:00 2001 From: Roger-luo Date: Fri, 26 Oct 2018 14:06:17 -0400 Subject: [PATCH 3/3] fix and add test --- src/tracker/array.jl | 4 ++-- test/tracker.jl | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 9c89b5f6..a93ca423 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -88,8 +88,8 @@ Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...) view(data(x), inds...), function (Δ) grad_output = zero(x) subgrad = view(grad_output, inds...) - subgrad[:] = Δ - (grad_output, map(_->nothing, inds)...) + subgrad[:] = data(Δ) + (nobacksies(:view, grad_output), map(_->nothing, inds)...) end end diff --git a/test/tracker.jl b/test/tracker.jl index 1f5f6240..baa65cce 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -33,6 +33,11 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...) @test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5)) @test gradtest(x -> x', rand(5)) + +@testset "indexing & slicing" begin + gradtest(x->view(x, 1:2, 1:2), rand(4, 4)) +end + function promotiontest(f, A, B, C) r0 = f(A, B, C) r1 = f(param(A), B, C)