diff --git a/src/tracker/array.jl b/src/tracker/array.jl index d5c04b5c..1a5c6c1a 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -327,6 +327,9 @@ end using ForwardDiff: Dual, partials, value +_size(x::AbstractArray) = size(x) +_size(x) = () + dualify(xs, n) = xs dualify(xs::AbstractArray, ps) = map(x -> Dual(x, ps), xs) dualify(xs::Real, ps) = Dual(xs, ps) @@ -343,7 +346,7 @@ function getpartial(Δ, x, i) end function ∇broadcast(f, args::Vararg{Any,N}) where N - sizes = size.(args) + sizes = _size.(args) dargs = map((x,i) -> dualify(data(x), ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N})) out = broadcast(f, dargs...) eltype(out) <: Dual || return out @@ -358,14 +361,14 @@ function ∇broadcast(f, args::Vararg{Any,N}) where N track(Call(back, tracker.(args)), y) end -Base.Broadcast._containertype(::Type{<:TrackedReal}) = TrackedArray -Base.Broadcast._containertype(::Type{<:TrackedArray}) = TrackedArray -Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{TrackedArray}) = TrackedArray -Base.Broadcast.promote_containertype(::Type{Array}, ::Type{TrackedArray}) = TrackedArray -Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{Array}) = TrackedArray -Base.Broadcast.promote_containertype(::Type{TrackedArray}, ct) = TrackedArray -Base.Broadcast.promote_containertype(ct, ::Type{TrackedArray}) = TrackedArray -Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A::Ref) = () -Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A) = indices(A) +using Base.Broadcast: BroadcastStyle -Base.Broadcast.broadcast_c(f, ::Type{TrackedArray}, A, Bs...) = ∇broadcast(f, A, Bs...) +struct TrackedStyle <: BroadcastStyle end + +Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray,TrackedReal}}) = TrackedStyle() +Broadcast.BroadcastStyle(::TrackedStyle, ::BroadcastStyle) = TrackedStyle() + +function Base.copy(bc::Broadcast.Broadcasted{TrackedStyle}) + bc = Broadcast.flatten(bc) + ∇broadcast(bc.f, bc.args...) +end