diff --git a/src/tracker/array.jl b/src/tracker/array.jl index e9fa1a1b..559891da 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -370,16 +370,26 @@ function ∇broadcast(f, args::Vararg{Any,N}) where N track(Call(back, tracker.(args)), y) end -using Base.Broadcast: BroadcastStyle +using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted struct TrackedStyle <: BroadcastStyle end Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray,TrackedReal}}) = TrackedStyle() Broadcast.BroadcastStyle(::TrackedStyle, ::BroadcastStyle) = TrackedStyle() -function Base.copy(bc::Broadcast.Broadcasted{TrackedStyle}) - bc = Broadcast.flatten(bc) - ∇broadcast(bc.f, bc.args...) +# We have to re-build the original broadcast struct to get the appropriate array +# style. We need this primarily to support CuArrays' broadcasting fixes. +broadcast_rebuild(xs) = data(xs) + +broadcast_rebuild(bc::Broadcasted) = + broadcasted(bc.f, broadcast_rebuild.(bc.args)...) + +preprocess(x) = x + +function Base.Broadcast.materialize(bc::Broadcasted{TrackedStyle}) + bc1 = Broadcast.flatten(bc) + bc2 = Broadcast.flatten(broadcast_rebuild(bc)) + ∇broadcast(bc2.f, bc1.args...) end using Requires