rewrite broadcast
This commit is contained in:
parent
e13d28a7a2
commit
86cf22675f
@ -327,35 +327,33 @@ end
|
||||
|
||||
using ForwardDiff: Dual, partials, value
|
||||
|
||||
_size(x::AbstractArray) = size(x)
|
||||
_size(x) = ()
|
||||
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
|
||||
|
||||
dualify(xs, n) = xs
|
||||
dualify(xs::AbstractArray, ps) = map(x -> Dual(x, ps), xs)
|
||||
dualify(xs::Real, ps) = Dual(xs, ps)
|
||||
unbroadcast(x::AbstractArray, Δ) =
|
||||
size(x) == size(Δ) ? Δ :
|
||||
length(x) == length(Δ) ? trim(x, Δ) :
|
||||
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
|
||||
|
||||
unbroadcast(x::Tuple, Δ) =
|
||||
x == size(Δ) ? Δ :
|
||||
reshape(sum(Δ, dims = filter(n -> n > length(x) || x[n] == 1, 1:ndims(Δ))), x)
|
||||
unbroadcast(x::Number, Δ) = sum(Δ)
|
||||
unbroadcast(x::Base.RefValue{<:Function}, _) = nothing
|
||||
unbroadcast(x::Base.RefValue{<:Val}, _) = nothing
|
||||
|
||||
unbroadcast(x::Tuple{}, Δ) = sum(Δ)
|
||||
dual(x, p) = x
|
||||
dual(x::Real, p) = Dual(x, p)
|
||||
|
||||
function getpartial(Δ, x, i)
|
||||
@inbounds p = getindex(partials(x), i)
|
||||
return Δ * p
|
||||
function partial(f::F, Δ, i, args::Vararg{Any,N}) where {F,N}
|
||||
dargs = ntuple(j -> dual(args[j], i==j), Val(N))
|
||||
return Δ * f(dargs...).partials[1]
|
||||
end
|
||||
|
||||
function ∇broadcast(f, args::Vararg{Any,N}) where N
|
||||
sizes = _size.(args)
|
||||
dargs = map((x,i) -> dualify(data(x), ntuple(j -> i==j, Val(N))), args, ntuple(identity, Val(N)))
|
||||
out = broadcast(f, dargs...)
|
||||
eltype(out) <: Dual || return out
|
||||
y = value.(out)
|
||||
back = function (Δ_)
|
||||
Δ = data(Δ_)
|
||||
Δargs = ntuple(i -> getpartial.(Δ, out, i), Val(N))
|
||||
dxs = map((x, Δ) -> unbroadcast(x, Δ), sizes, Δargs)
|
||||
nobacksies(:broadcast, dxs)
|
||||
@inline function ∇broadcast(f::F, args::Vararg{Any,N}) where {F,N}
|
||||
y = broadcast(f, data.(args)...)
|
||||
eltype(y) <: Real || return y
|
||||
eltype(y) == Bool && return y
|
||||
function back(Δ)
|
||||
Δargs = ntuple(i -> partial.(f, data(Δ), i, args...), Val(N))
|
||||
dxs = unbroadcast.(args, Δargs)
|
||||
return nobacksies(:broadcast, dxs)
|
||||
end
|
||||
# So we can return non-tracked arrays
|
||||
track(Call(back, tracker.(args)), y)
|
||||
|
Loading…
Reference in New Issue
Block a user