rewrite broadcast

This commit is contained in:
Mike Innes 2018-08-24 14:07:08 +01:00
parent e13d28a7a2
commit 86cf22675f

View File

@ -327,35 +327,33 @@ end
using ForwardDiff: Dual, partials, value
_size(x::AbstractArray) = size(x)
_size(x) = ()
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
dualify(xs, n) = xs
dualify(xs::AbstractArray, ps) = map(x -> Dual(x, ps), xs)
dualify(xs::Real, ps) = Dual(xs, ps)
unbroadcast(x::AbstractArray, Δ) =
size(x) == size(Δ) ? Δ :
length(x) == length(Δ) ? trim(x, Δ) :
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
unbroadcast(x::Tuple, Δ) =
x == size(Δ) ? Δ :
reshape(sum(Δ, dims = filter(n -> n > length(x) || x[n] == 1, 1:ndims(Δ))), x)
unbroadcast(x::Number, Δ) = sum(Δ)
unbroadcast(x::Base.RefValue{<:Function}, _) = nothing
unbroadcast(x::Base.RefValue{<:Val}, _) = nothing
unbroadcast(x::Tuple{}, Δ) = sum(Δ)
dual(x, p) = x
dual(x::Real, p) = Dual(x, p)
function getpartial(Δ, x, i)
@inbounds p = getindex(partials(x), i)
return Δ * p
function partial(f::F, Δ, i, args::Vararg{Any,N}) where {F,N}
dargs = ntuple(j -> dual(args[j], i==j), Val(N))
return Δ * f(dargs...).partials[1]
end
function ∇broadcast(f, args::Vararg{Any,N}) where N
sizes = _size.(args)
dargs = map((x,i) -> dualify(data(x), ntuple(j -> i==j, Val(N))), args, ntuple(identity, Val(N)))
out = broadcast(f, dargs...)
eltype(out) <: Dual || return out
y = value.(out)
back = function (Δ_)
Δ = data(Δ_)
Δargs = ntuple(i -> getpartial.(Δ, out, i), Val(N))
dxs = map((x, Δ) -> unbroadcast(x, Δ), sizes, Δargs)
nobacksies(:broadcast, dxs)
@inline function ∇broadcast(f::F, args::Vararg{Any,N}) where {F,N}
y = broadcast(f, data.(args)...)
eltype(y) <: Real || return y
eltype(y) == Bool && return y
function back(Δ)
Δargs = ntuple(i -> partial.(f, data(Δ), i, args...), Val(N))
dxs = unbroadcast.(args, Δargs)
return nobacksies(:broadcast, dxs)
end
# So we can return non-tracked arrays
track(Call(back, tracker.(args)), y)