de-broadcasting
This commit is contained in:
parent
bafecfede1
commit
56ed6f5680
@ -42,7 +42,8 @@ data(x::TrackedArray) = x.x
|
|||||||
grad(x::TrackedArray) = x.Δ
|
grad(x::TrackedArray) = x.Δ
|
||||||
|
|
||||||
function back!(x::TrackedArray, Δ)
|
function back!(x::TrackedArray, Δ)
|
||||||
x.Δ .+= Δ
|
Δ′ = vec(x.Δ)
|
||||||
|
Δ′ .+= vec(Δ)
|
||||||
back!(x.f, Δ)
|
back!(x.f, Δ)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -57,9 +57,13 @@ function tracked_broadcast(f, args::Vararg{Any,N}) where N
|
|||||||
TrackedArray(Call(b, args...), b())
|
TrackedArray(Call(b, args...), b())
|
||||||
end
|
end
|
||||||
|
|
||||||
|
unbroadcast(x, Δ) =
|
||||||
|
size(x) == size(Δ) ? Δ :
|
||||||
|
sum(Δ, filter(n -> size(x, n) == 1, 1:ndims(Δ)))
|
||||||
|
|
||||||
function back!(b::Broadcasted, Δ, args...)
|
function back!(b::Broadcasted, Δ, args...)
|
||||||
Δargs = ntuple(i -> Δ .* getindex.(partials.(b.data), i), length(args))
|
Δargs = ntuple(i -> Δ .* getindex.(partials.(b.data), i), length(args))
|
||||||
map((x, Δ) -> @back!(x, Δ), args, Δargs)
|
foreach((x, Δ) -> @back!(x, unbroadcast(x, Δ)), args, Δargs)
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user