rewrite broadcast

This commit is contained in:
Mike Innes 2018-08-24 14:07:08 +01:00
parent e13d28a7a2
commit 86cf22675f

View File

@ -327,35 +327,33 @@ end
using ForwardDiff: Dual, partials, value using ForwardDiff: Dual, partials, value
_size(x::AbstractArray) = size(x) trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
_size(x) = ()
dualify(xs, n) = xs unbroadcast(x::AbstractArray, Δ) =
dualify(xs::AbstractArray, ps) = map(x -> Dual(x, ps), xs) size(x) == size(Δ) ? Δ :
dualify(xs::Real, ps) = Dual(xs, ps) length(x) == length(Δ) ? trim(x, Δ) :
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
unbroadcast(x::Tuple, Δ) = unbroadcast(x::Number, Δ) = sum(Δ)
x == size(Δ) ? Δ : unbroadcast(x::Base.RefValue{<:Function}, _) = nothing
reshape(sum(Δ, dims = filter(n -> n > length(x) || x[n] == 1, 1:ndims(Δ))), x) unbroadcast(x::Base.RefValue{<:Val}, _) = nothing
unbroadcast(x::Tuple{}, Δ) = sum(Δ) dual(x, p) = x
dual(x::Real, p) = Dual(x, p)
function getpartial(Δ, x, i) function partial(f::F, Δ, i, args::Vararg{Any,N}) where {F,N}
@inbounds p = getindex(partials(x), i) dargs = ntuple(j -> dual(args[j], i==j), Val(N))
return Δ * p return Δ * f(dargs...).partials[1]
end end
function ∇broadcast(f, args::Vararg{Any,N}) where N @inline function ∇broadcast(f::F, args::Vararg{Any,N}) where {F,N}
sizes = _size.(args) y = broadcast(f, data.(args)...)
dargs = map((x,i) -> dualify(data(x), ntuple(j -> i==j, Val(N))), args, ntuple(identity, Val(N))) eltype(y) <: Real || return y
out = broadcast(f, dargs...) eltype(y) == Bool && return y
eltype(out) <: Dual || return out function back(Δ)
y = value.(out) Δargs = ntuple(i -> partial.(f, data(Δ), i, args...), Val(N))
back = function (Δ_) dxs = unbroadcast.(args, Δargs)
Δ = data(Δ_) return nobacksies(:broadcast, dxs)
Δargs = ntuple(i -> getpartial.(Δ, out, i), Val(N))
dxs = map((x, Δ) -> unbroadcast(x, Δ), sizes, Δargs)
nobacksies(:broadcast, dxs)
end end
# So we can return non-tracked arrays # So we can return non-tracked arrays
track(Call(back, tracker.(args)), y) track(Call(back, tracker.(args)), y)