fix #458
This commit is contained in:
parent
bbccdb3eec
commit
5f99e5775a
@ -82,6 +82,17 @@ Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
|
||||
end
|
||||
end
|
||||
|
||||
Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...)
|
||||
|
||||
@grad function view(x::AbstractArray, inds...)
|
||||
view(data(x), inds...), function (Δ)
|
||||
grad_output = fill!(similar(data(x)), 0)
|
||||
subgrad = view(grad_output, inds...)
|
||||
setindex!(subgrad, Δ, :)
|
||||
(grad_output, map(_->nothing, inds)...)
|
||||
end
|
||||
end
|
||||
|
||||
Base.:-(xs::TrackedArray) = track(-, xs)
|
||||
|
||||
@grad -(xs) = -data(xs), Δ -> (-Δ,)
|
||||
|
Loading…
Reference in New Issue
Block a user