broadcast efficiency
This commit is contained in:
parent
e763c342ee
commit
80af9a3830
@ -78,8 +78,8 @@ Base.:-(xs::TrackedArray) = track(-, xs)
|
|||||||
Base.transpose(xs::TrackedArray) = track(transpose, xs)
|
Base.transpose(xs::TrackedArray) = track(transpose, xs)
|
||||||
Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs)
|
Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs)
|
||||||
|
|
||||||
@grad transpose(xs) = xs.', Δ -> (trim(xs, Δ.'),)
|
@grad transpose(xs) = xs.', Δ -> (reshape(Δ.', size(xs)),)
|
||||||
@grad ctranspose(xs) = xs', Δ -> (trim(xs, Δ'),)
|
@grad ctranspose(xs) = xs', Δ -> (reshape(Δ', size(xs)),)
|
||||||
|
|
||||||
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
|
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
|
||||||
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
|
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
|
||||||
@ -347,13 +347,11 @@ dualify(xs, n) = xs
|
|||||||
dualify(xs::AbstractArray, ps) = map(x -> Dual(x, ps), xs)
|
dualify(xs::AbstractArray, ps) = map(x -> Dual(x, ps), xs)
|
||||||
dualify(xs::Real, ps) = Dual(xs, ps)
|
dualify(xs::Real, ps) = Dual(xs, ps)
|
||||||
|
|
||||||
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val{ndims(x)}))
|
unbroadcast(x::Tuple, Δ) =
|
||||||
|
x == size(Δ) ? Δ :
|
||||||
|
reshape(sum(Δ, filter(n -> n > length(x) || x[n] == 1, 1:ndims(Δ))), x)
|
||||||
|
|
||||||
unbroadcast(x::AbstractArray, Δ) =
|
unbroadcast(x::Tuple{}, Δ) = sum(Δ)
|
||||||
size(x) == size(Δ) ? Δ :
|
|
||||||
trim(x, sum(Δ, filter(n -> size(x, n) == 1, 1:ndims(Δ))))
|
|
||||||
|
|
||||||
unbroadcast(x::Number, Δ) = sum(Δ)
|
|
||||||
|
|
||||||
function getpartial(Δ, x, i)
|
function getpartial(Δ, x, i)
|
||||||
@inbounds p = getindex(partials(x), i)
|
@inbounds p = getindex(partials(x), i)
|
||||||
@ -361,13 +359,14 @@ function getpartial(Δ, x, i)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function ∇broadcast(f, args::Vararg{Any,N}) where N
|
function ∇broadcast(f, args::Vararg{Any,N}) where N
|
||||||
|
sizes = size.(args)
|
||||||
dargs = map((x,i) -> dualify(data(x), ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
|
dargs = map((x,i) -> dualify(data(x), ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
|
||||||
out = broadcast(f, dargs...)
|
out = broadcast(f, dargs...)
|
||||||
eltype(out) <: Dual || return out
|
eltype(out) <: Dual || return out
|
||||||
y = value.(out)
|
y = value.(out)
|
||||||
back = function (Δ)
|
back = function (Δ)
|
||||||
Δargs = ntuple(i -> getpartial.(Δ, out, i), Val{N})
|
Δargs = ntuple(i -> getpartial.(Δ, out, i), Val{N})
|
||||||
map((x, Δ) -> unbroadcast(x, Δ), args, Δargs)
|
map((x, Δ) -> unbroadcast(x, Δ), sizes, Δargs)
|
||||||
end
|
end
|
||||||
# So we can return non-tracked arrays
|
# So we can return non-tracked arrays
|
||||||
track(Call(back, tracker.(args)), y)
|
track(Call(back, tracker.(args)), y)
|
||||||
|
Loading…
Reference in New Issue
Block a user