From 80af9a3830afbda23b5c8c3b1cc18f129220b23d Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 9 Jul 2018 23:40:07 +0100 Subject: [PATCH] broadcast efficiency --- src/tracker/array.jl | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index dbf789ac..ee8c9b39 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -78,8 +78,8 @@ Base.:-(xs::TrackedArray) = track(-, xs) Base.transpose(xs::TrackedArray) = track(transpose, xs) Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs) -@grad transpose(xs) = xs.', Δ -> (trim(xs, Δ.'),) -@grad ctranspose(xs) = xs', Δ -> (trim(xs, Δ'),) +@grad transpose(xs) = xs.', Δ -> (reshape(Δ.', size(xs)),) +@grad ctranspose(xs) = xs', Δ -> (reshape(Δ', size(xs)),) Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...) Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...) @@ -347,13 +347,11 @@ dualify(xs, n) = xs dualify(xs::AbstractArray, ps) = map(x -> Dual(x, ps), xs) dualify(xs::Real, ps) = Dual(xs, ps) -trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val{ndims(x)})) +unbroadcast(x::Tuple, Δ) = + x == size(Δ) ? Δ : + reshape(sum(Δ, filter(n -> n > length(x) || x[n] == 1, 1:ndims(Δ))), x) -unbroadcast(x::AbstractArray, Δ) = - size(x) == size(Δ) ? Δ : - trim(x, sum(Δ, filter(n -> size(x, n) == 1, 1:ndims(Δ)))) - -unbroadcast(x::Number, Δ) = sum(Δ) +unbroadcast(x::Tuple{}, Δ) = sum(Δ) function getpartial(Δ, x, i) @inbounds p = getindex(partials(x), i) @@ -361,13 +359,14 @@ function getpartial(Δ, x, i) end function ∇broadcast(f, args::Vararg{Any,N}) where N + sizes = size.(args) dargs = map((x,i) -> dualify(data(x), ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N})) out = broadcast(f, dargs...) eltype(out) <: Dual || return out y = value.(out) back = function (Δ) Δargs = ntuple(i -> getpartial.(Δ, out, i), Val{N}) - map((x, Δ) -> unbroadcast(x, Δ), args, Δargs) + map((x, Δ) -> unbroadcast(x, Δ), sizes, Δargs) end # So we can return non-tracked arrays track(Call(back, tracker.(args)), y)