scalar getindex backprop

This commit is contained in:
Mike J Innes 2017-09-03 17:10:23 -04:00
parent 47ba702747
commit 8f4ccdd5ba

View File

@ -3,12 +3,15 @@ import Base: *
toarray(xs::AbstractArray, ys::AbstractArray) = ys
toarray(xs::AbstractArray, y) = similar(xs, typeof(y), ()) .= y
unarray(xs) = xs
unarray(xs::AbstractArray{T,0} where T) = xs[]
Base.getindex(xs::TrackedArray, i...) =
TrackedArray(Call(getindex, xs, i...), toarray(xs.x, xs.x[i...]))
function back!(::typeof(getindex), Δ, xs::TrackedArray, i...)
Δ′ = zeros(xs.x)
Δ′[i...] = Δ
Δ′[i...] = unarray(Δ)
@back!(xs, Δ′)
end
@ -19,6 +22,9 @@ back!(::typeof(-), Δ, xs::TrackedArray) = back!(xs, -Δ)
Base.transpose(xs::TrackedArray) = TrackedArray(Call(transpose, xs))
Base.ctranspose(xs::TrackedArray) = TrackedArray(Call(ctranspose, xs))
back!(::typeof(transpose), Δ, xs) = @back!(xs, trim(xs, Δ.'))
back!(::typeof(ctranspose), Δ, xs) = @back!(xs, trim(xs, Δ'))
Base.vcat(a::TrackedVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::TrackedVector, b::AbstractVector) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::AbstractVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))