diff --git a/src/Tracker/Tracker.jl b/src/Tracker/Tracker.jl index 3eaa27fb..4bf9b3fe 100644 --- a/src/Tracker/Tracker.jl +++ b/src/Tracker/Tracker.jl @@ -43,10 +43,17 @@ function back!(x::TrackedArray, Δ) back!(x.f, Δ) end -for f in :[Base.size, Base.ndims, Base.similar].args +# Fallthrough methods + +for f in :[Base.size, Base.ndims].args @eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...) end +Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) = + similar(data(x), dims...) + +Base.similar(x::TrackedArray, T::Type) = similar(data(x), T) + function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true) if repr print(io, "TrackedArray(") @@ -58,4 +65,6 @@ function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = tru end end +include("lib.jl") + end diff --git a/src/Tracker/lib.jl b/src/Tracker/lib.jl index 00fbb407..31e26f66 100644 --- a/src/Tracker/lib.jl +++ b/src/Tracker/lib.jl @@ -1,10 +1,52 @@ import Base: * -a::TrackedMatrix * b::Union{TrackedMatrix,AbstractMatrix} = Var(Call(*, a, b)) -a::Union{TrackedMatrix,AbstractMatrix} * b::TrackedMatrix = Var(Call(*, a, b)) +Base.getindex(xs::TrackedArray, i...) = TrackedArray(Call(getindex, xs, i...)) -function back!(::typeof(*), Δ, a::AbstractArray, b::AbstractArray) +function back!(::typeof(getindex), Δ, xs::TrackedArray, i...) + Δ′ = zeros(xs) + Δ′[i...] = Δ + back!(xs, Δ′) +end + +Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs)) + +a::TrackedMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b)) +a::TrackedMatrix * b::AbstractMatrix = TrackedArray(Call(*, a, b)) +a::AbstractMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b)) + +function back!(::typeof(*), Δ, a::AbstractMatrix, b::AbstractMatrix) back!(a, A_mul_Bt(Δ, data(b))) back!(b, At_mul_B(data(a), Δ)) +end + +# Broadcasting + +using ForwardDiff: Dual, partials + +struct Broadcasted{T} + data::T +end + +(b::Broadcasted)(xs...) = map(x -> x.value, b.data) + +dualify(xs, n) = xs +dualify(xs::TrackedArray, ps) = Dual.(data(xs), Ref(ps)) + +function tracked_broadcast(f, args::Vararg{Any,N}) where N + dargs = ntuple(i -> dualify(args[i], ntuple(j -> i==j, Val{N})), Val{N}) + TrackedArray(Call(Broadcasted(broadcast(f, dargs...)), args...)) +end + +function back!(b::Broadcasted, Δ, args...) + Δargs = ntuple(i -> Δ .* getindex.(partials.(b.data), i), length(args)) + back!.(args, Δargs) return end + +Base.Broadcast._containertype(::Type{<:TrackedArray}) = TrackedArray +Base.Broadcast.promote_containertype(::Type{Array}, ::Type{TrackedArray}) = TrackedArray +Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{Array}) = TrackedArray +Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A::Ref) = () +Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A) = indices(A) + +Base.Broadcast.broadcast_c(f, ::Type{TrackedArray}, A, Bs...) = tracked_broadcast(f, A, Bs...)