scalar getindex backprop
This commit is contained in:
parent
47ba702747
commit
8f4ccdd5ba
@ -3,12 +3,15 @@ import Base: *
|
||||
toarray(xs::AbstractArray, ys::AbstractArray) = ys
|
||||
toarray(xs::AbstractArray, y) = similar(xs, typeof(y), ()) .= y
|
||||
|
||||
unarray(xs) = xs
|
||||
unarray(xs::AbstractArray{T,0} where T) = xs[]
|
||||
|
||||
Base.getindex(xs::TrackedArray, i...) =
|
||||
TrackedArray(Call(getindex, xs, i...), toarray(xs.x, xs.x[i...]))
|
||||
|
||||
function back!(::typeof(getindex), Δ, xs::TrackedArray, i...)
|
||||
Δ′ = zeros(xs.x)
|
||||
Δ′[i...] = Δ
|
||||
Δ′[i...] = unarray(Δ)
|
||||
@back!(xs, Δ′)
|
||||
end
|
||||
|
||||
@ -19,6 +22,9 @@ back!(::typeof(-), Δ, xs::TrackedArray) = back!(xs, -Δ)
|
||||
Base.transpose(xs::TrackedArray) = TrackedArray(Call(transpose, xs))
|
||||
Base.ctranspose(xs::TrackedArray) = TrackedArray(Call(ctranspose, xs))
|
||||
|
||||
back!(::typeof(transpose), Δ, xs) = @back!(xs, trim(xs, Δ.'))
|
||||
back!(::typeof(ctranspose), Δ, xs) = @back!(xs, trim(xs, Δ'))
|
||||
|
||||
Base.vcat(a::TrackedVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))
|
||||
Base.vcat(a::TrackedVector, b::AbstractVector) = TrackedArray(Call(vcat, a, b))
|
||||
Base.vcat(a::AbstractVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))
|
||||
|
Loading…
Reference in New Issue
Block a user