From 0b89e1374cf1c0013685c64c88de972ab7476236 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 28 Aug 2017 01:40:59 +0100 Subject: [PATCH] gpu-friendly --- src/tracker/lib.jl | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index 76f420d3..23f4a670 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -72,10 +72,14 @@ unbroadcast(x, Δ) = size(x) == size(Δ) ? Δ : sum(Δ, filter(n -> size(x, n) == 1, 1:ndims(Δ))) -function back!(b::Broadcasted, Δ, args...) - Δargs = ntuple(i -> Δ .* getindex.(partials.(b.data), i), length(args)) +function getpartial(Δ, x, i) + @inbounds p = getindex(partials(x), i) + return Δ * p +end + +function back!(b::Broadcasted, Δ, args::Vararg{Any,N}) where N + Δargs = ntuple(i -> getpartial.(Δ, b.data, i), Val{N}) foreach((x, Δ) -> @back!(x, unbroadcast(x, Δ)), args, Δargs) - return end Base.Broadcast._containertype(::Type{<:TrackedArray}) = TrackedArray