diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index 796b34e9..8031db0f 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -3,12 +3,15 @@ import Base: * toarray(xs::AbstractArray, ys::AbstractArray) = ys toarray(xs::AbstractArray, y) = similar(xs, typeof(y), ()) .= y +unarray(xs) = xs +unarray(xs::AbstractArray{T,0} where T) = xs[] + Base.getindex(xs::TrackedArray, i...) = TrackedArray(Call(getindex, xs, i...), toarray(xs.x, xs.x[i...])) function back!(::typeof(getindex), Δ, xs::TrackedArray, i...) Δ′ = zeros(xs.x) - Δ′[i...] = Δ + Δ′[i...] = unarray(Δ) @back!(xs, Δ′) end @@ -19,6 +22,9 @@ back!(::typeof(-), Δ, xs::TrackedArray) = back!(xs, -Δ) Base.transpose(xs::TrackedArray) = TrackedArray(Call(transpose, xs)) Base.ctranspose(xs::TrackedArray) = TrackedArray(Call(ctranspose, xs)) +back!(::typeof(transpose), Δ, xs) = @back!(xs, trim(xs, Δ.')) +back!(::typeof(ctranspose), Δ, xs) = @back!(xs, trim(xs, Δ')) + Base.vcat(a::TrackedVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b)) Base.vcat(a::TrackedVector, b::AbstractVector) = TrackedArray(Call(vcat, a, b)) Base.vcat(a::AbstractVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))