From 86cf22675fd82ddab34d111aa36002a728284f15 Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Fri, 24 Aug 2018 14:07:08 +0100 Subject: [PATCH] rewrite broadcast --- src/tracker/array.jl | 44 +++++++++++++++++++++----------------------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index cef4463d..5e76ddf4 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -327,35 +327,33 @@ end using ForwardDiff: Dual, partials, value -_size(x::AbstractArray) = size(x) -_size(x) = () +trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x)))) -dualify(xs, n) = xs -dualify(xs::AbstractArray, ps) = map(x -> Dual(x, ps), xs) -dualify(xs::Real, ps) = Dual(xs, ps) +unbroadcast(x::AbstractArray, Δ) = + size(x) == size(Δ) ? Δ : + length(x) == length(Δ) ? trim(x, Δ) : + trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ))))) -unbroadcast(x::Tuple, Δ) = - x == size(Δ) ? Δ : - reshape(sum(Δ, dims = filter(n -> n > length(x) || x[n] == 1, 1:ndims(Δ))), x) +unbroadcast(x::Number, Δ) = sum(Δ) +unbroadcast(x::Base.RefValue{<:Function}, _) = nothing +unbroadcast(x::Base.RefValue{<:Val}, _) = nothing -unbroadcast(x::Tuple{}, Δ) = sum(Δ) +dual(x, p) = x +dual(x::Real, p) = Dual(x, p) -function getpartial(Δ, x, i) - @inbounds p = getindex(partials(x), i) - return Δ * p +function partial(f::F, Δ, i, args::Vararg{Any,N}) where {F,N} + dargs = ntuple(j -> dual(args[j], i==j), Val(N)) + return Δ * f(dargs...).partials[1] end -function ∇broadcast(f, args::Vararg{Any,N}) where N - sizes = _size.(args) - dargs = map((x,i) -> dualify(data(x), ntuple(j -> i==j, Val(N))), args, ntuple(identity, Val(N))) - out = broadcast(f, dargs...) - eltype(out) <: Dual || return out - y = value.(out) - back = function (Δ_) - Δ = data(Δ_) - Δargs = ntuple(i -> getpartial.(Δ, out, i), Val(N)) - dxs = map((x, Δ) -> unbroadcast(x, Δ), sizes, Δargs) - nobacksies(:broadcast, dxs) +@inline function ∇broadcast(f::F, args::Vararg{Any,N}) where {F,N} + y = broadcast(f, data.(args)...) + eltype(y) <: Real || return y + eltype(y) == Bool && return y + function back(Δ) + Δargs = ntuple(i -> partial.(f, data(Δ), i, args...), Val(N)) + dxs = unbroadcast.(args, Δargs) + return nobacksies(:broadcast, dxs) end # So we can return non-tracked arrays track(Call(back, tracker.(args)), y)