Merge pull request #461 from Roger-luo/roger-patch-1

Support view for TrackedArray
This commit is contained in:
Mike J Innes 2018-10-30 15:24:05 +00:00 committed by GitHub
commit 9312536b96
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 0 deletions

View File

@ -82,6 +82,17 @@ Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
end end
end end
Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...)
@grad function view(x::AbstractArray, inds...)
view(data(x), inds...), function (Δ)
grad_output = zero(x)
subgrad = view(grad_output, inds...)
subgrad[:] = data(Δ)
(nobacksies(:view, grad_output), map(_->nothing, inds)...)
end
end
Base.:-(xs::TrackedArray) = track(-, xs) Base.:-(xs::TrackedArray) = track(-, xs)
@grad -(xs) = -data(xs), Δ -> (-Δ,) @grad -(xs) = -data(xs), Δ -> (-Δ,)

View File

@ -33,6 +33,11 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5)) @test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))
@test gradtest(x -> x', rand(5)) @test gradtest(x -> x', rand(5))
@testset "indexing & slicing" begin
gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
end
function promotiontest(f, A, B, C) function promotiontest(f, A, B, C)
r0 = f(A, B, C) r0 = f(A, B, C)
r1 = f(param(A), B, C) r1 = f(param(A), B, C)