diff --git a/src/tracker/array.jl b/src/tracker/array.jl index c950413c..bb42e70f 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -291,6 +291,7 @@ function back(b::Broadcasted, Δ, args::Vararg{Any,N}) where N foreach((x, Δ) -> @back(x, unbroadcast(x, Δ)), args, Δargs) end +Base.Broadcast._containertype(::Type{<:TrackedReal}) = TrackedArray Base.Broadcast._containertype(::Type{<:TrackedArray}) = TrackedArray Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{TrackedArray}) = TrackedArray Base.Broadcast.promote_containertype(::Type{Array}, ::Type{TrackedArray}) = TrackedArray