Merge pull request #461 from Roger-luo/roger-patch-1
Support view for TrackedArray
This commit is contained in:
commit
9312536b96
@ -82,6 +82,17 @@ Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...)
|
||||||
|
|
||||||
|
@grad function view(x::AbstractArray, inds...)
|
||||||
|
view(data(x), inds...), function (Δ)
|
||||||
|
grad_output = zero(x)
|
||||||
|
subgrad = view(grad_output, inds...)
|
||||||
|
subgrad[:] = data(Δ)
|
||||||
|
(nobacksies(:view, grad_output), map(_->nothing, inds)...)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
Base.:-(xs::TrackedArray) = track(-, xs)
|
Base.:-(xs::TrackedArray) = track(-, xs)
|
||||||
|
|
||||||
@grad -(xs) = -data(xs), Δ -> (-Δ,)
|
@grad -(xs) = -data(xs), Δ -> (-Δ,)
|
||||||
|
@ -33,6 +33,11 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
|
|||||||
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))
|
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))
|
||||||
|
|
||||||
@test gradtest(x -> x', rand(5))
|
@test gradtest(x -> x', rand(5))
|
||||||
|
|
||||||
|
@testset "indexing & slicing" begin
|
||||||
|
gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
|
||||||
|
end
|
||||||
|
|
||||||
function promotiontest(f, A, B, C)
|
function promotiontest(f, A, B, C)
|
||||||
r0 = f(A, B, C)
|
r0 = f(A, B, C)
|
||||||
r1 = f(param(A), B, C)
|
r1 = f(param(A), B, C)
|
||||||
|
Loading…
Reference in New Issue
Block a user