This commit is contained in:
Roger-luo 2018-10-24 15:40:10 -04:00
parent bbccdb3eec
commit 5f99e5775a

View File

@ -82,6 +82,17 @@ Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
end
end
Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...)
@grad function view(x::AbstractArray, inds...)
view(data(x), inds...), function (Δ)
grad_output = fill!(similar(data(x)), 0)
subgrad = view(grad_output, inds...)
setindex!(subgrad, Δ, :)
(grad_output, map(_->nothing, inds)...)
end
end
Base.:-(xs::TrackedArray) = track(-, xs)
@grad -(xs) = -data(xs), Δ -> (-Δ,)