rewrite broadcast
This commit is contained in:
parent
e13d28a7a2
commit
86cf22675f
@ -327,35 +327,33 @@ end
|
|||||||
|
|
||||||
using ForwardDiff: Dual, partials, value
|
using ForwardDiff: Dual, partials, value
|
||||||
|
|
||||||
_size(x::AbstractArray) = size(x)
|
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
|
||||||
_size(x) = ()
|
|
||||||
|
|
||||||
dualify(xs, n) = xs
|
unbroadcast(x::AbstractArray, Δ) =
|
||||||
dualify(xs::AbstractArray, ps) = map(x -> Dual(x, ps), xs)
|
size(x) == size(Δ) ? Δ :
|
||||||
dualify(xs::Real, ps) = Dual(xs, ps)
|
length(x) == length(Δ) ? trim(x, Δ) :
|
||||||
|
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
|
||||||
|
|
||||||
unbroadcast(x::Tuple, Δ) =
|
unbroadcast(x::Number, Δ) = sum(Δ)
|
||||||
x == size(Δ) ? Δ :
|
unbroadcast(x::Base.RefValue{<:Function}, _) = nothing
|
||||||
reshape(sum(Δ, dims = filter(n -> n > length(x) || x[n] == 1, 1:ndims(Δ))), x)
|
unbroadcast(x::Base.RefValue{<:Val}, _) = nothing
|
||||||
|
|
||||||
unbroadcast(x::Tuple{}, Δ) = sum(Δ)
|
dual(x, p) = x
|
||||||
|
dual(x::Real, p) = Dual(x, p)
|
||||||
|
|
||||||
function getpartial(Δ, x, i)
|
function partial(f::F, Δ, i, args::Vararg{Any,N}) where {F,N}
|
||||||
@inbounds p = getindex(partials(x), i)
|
dargs = ntuple(j -> dual(args[j], i==j), Val(N))
|
||||||
return Δ * p
|
return Δ * f(dargs...).partials[1]
|
||||||
end
|
end
|
||||||
|
|
||||||
function ∇broadcast(f, args::Vararg{Any,N}) where N
|
@inline function ∇broadcast(f::F, args::Vararg{Any,N}) where {F,N}
|
||||||
sizes = _size.(args)
|
y = broadcast(f, data.(args)...)
|
||||||
dargs = map((x,i) -> dualify(data(x), ntuple(j -> i==j, Val(N))), args, ntuple(identity, Val(N)))
|
eltype(y) <: Real || return y
|
||||||
out = broadcast(f, dargs...)
|
eltype(y) == Bool && return y
|
||||||
eltype(out) <: Dual || return out
|
function back(Δ)
|
||||||
y = value.(out)
|
Δargs = ntuple(i -> partial.(f, data(Δ), i, args...), Val(N))
|
||||||
back = function (Δ_)
|
dxs = unbroadcast.(args, Δargs)
|
||||||
Δ = data(Δ_)
|
return nobacksies(:broadcast, dxs)
|
||||||
Δargs = ntuple(i -> getpartial.(Δ, out, i), Val(N))
|
|
||||||
dxs = map((x, Δ) -> unbroadcast(x, Δ), sizes, Δargs)
|
|
||||||
nobacksies(:broadcast, dxs)
|
|
||||||
end
|
end
|
||||||
# So we can return non-tracked arrays
|
# So we can return non-tracked arrays
|
||||||
track(Call(back, tracker.(args)), y)
|
track(Call(back, tracker.(args)), y)
|
||||||
|
Loading…
Reference in New Issue
Block a user