gpu-friendly
This commit is contained in:
parent
12dc6b66c5
commit
0b89e1374c
@ -72,10 +72,14 @@ unbroadcast(x, Δ) =
|
|||||||
size(x) == size(Δ) ? Δ :
|
size(x) == size(Δ) ? Δ :
|
||||||
sum(Δ, filter(n -> size(x, n) == 1, 1:ndims(Δ)))
|
sum(Δ, filter(n -> size(x, n) == 1, 1:ndims(Δ)))
|
||||||
|
|
||||||
function back!(b::Broadcasted, Δ, args...)
|
function getpartial(Δ, x, i)
|
||||||
Δargs = ntuple(i -> Δ .* getindex.(partials.(b.data), i), length(args))
|
@inbounds p = getindex(partials(x), i)
|
||||||
|
return Δ * p
|
||||||
|
end
|
||||||
|
|
||||||
|
function back!(b::Broadcasted, Δ, args::Vararg{Any,N}) where N
|
||||||
|
Δargs = ntuple(i -> getpartial.(Δ, b.data, i), Val{N})
|
||||||
foreach((x, Δ) -> @back!(x, unbroadcast(x, Δ)), args, Δargs)
|
foreach((x, Δ) -> @back!(x, unbroadcast(x, Δ)), args, Δargs)
|
||||||
return
|
|
||||||
end
|
end
|
||||||
|
|
||||||
Base.Broadcast._containertype(::Type{<:TrackedArray}) = TrackedArray
|
Base.Broadcast._containertype(::Type{<:TrackedArray}) = TrackedArray
|
||||||
|
Loading…
Reference in New Issue
Block a user