broadcast fixes
This commit is contained in:
parent
5a023a9ccc
commit
e68b8765b6
@ -370,16 +370,26 @@ function ∇broadcast(f, args::Vararg{Any,N}) where N
|
|||||||
track(Call(back, tracker.(args)), y)
|
track(Call(back, tracker.(args)), y)
|
||||||
end
|
end
|
||||||
|
|
||||||
using Base.Broadcast: BroadcastStyle
|
using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted
|
||||||
|
|
||||||
struct TrackedStyle <: BroadcastStyle end
|
struct TrackedStyle <: BroadcastStyle end
|
||||||
|
|
||||||
Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray,TrackedReal}}) = TrackedStyle()
|
Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray,TrackedReal}}) = TrackedStyle()
|
||||||
Broadcast.BroadcastStyle(::TrackedStyle, ::BroadcastStyle) = TrackedStyle()
|
Broadcast.BroadcastStyle(::TrackedStyle, ::BroadcastStyle) = TrackedStyle()
|
||||||
|
|
||||||
function Base.copy(bc::Broadcast.Broadcasted{TrackedStyle})
|
# We have to re-build the original broadcast struct to get the appropriate array
|
||||||
bc = Broadcast.flatten(bc)
|
# style. We need this primarily to support CuArrays' broadcasting fixes.
|
||||||
∇broadcast(bc.f, bc.args...)
|
broadcast_rebuild(xs) = data(xs)
|
||||||
|
|
||||||
|
broadcast_rebuild(bc::Broadcasted) =
|
||||||
|
broadcasted(bc.f, broadcast_rebuild.(bc.args)...)
|
||||||
|
|
||||||
|
preprocess(x) = x
|
||||||
|
|
||||||
|
function Base.Broadcast.materialize(bc::Broadcasted{TrackedStyle})
|
||||||
|
bc1 = Broadcast.flatten(bc)
|
||||||
|
bc2 = Broadcast.flatten(broadcast_rebuild(bc))
|
||||||
|
∇broadcast(bc2.f, bc1.args...)
|
||||||
end
|
end
|
||||||
|
|
||||||
using Requires
|
using Requires
|
||||||
|
Loading…
Reference in New Issue
Block a user