fix broadcasting
This commit is contained in:
parent
e486c50610
commit
adc216f182
@ -327,6 +327,9 @@ end
|
|||||||
|
|
||||||
using ForwardDiff: Dual, partials, value
|
using ForwardDiff: Dual, partials, value
|
||||||
|
|
||||||
|
_size(x::AbstractArray) = size(x)
|
||||||
|
_size(x) = ()
|
||||||
|
|
||||||
dualify(xs, n) = xs
|
dualify(xs, n) = xs
|
||||||
dualify(xs::AbstractArray, ps) = map(x -> Dual(x, ps), xs)
|
dualify(xs::AbstractArray, ps) = map(x -> Dual(x, ps), xs)
|
||||||
dualify(xs::Real, ps) = Dual(xs, ps)
|
dualify(xs::Real, ps) = Dual(xs, ps)
|
||||||
@ -343,7 +346,7 @@ function getpartial(Δ, x, i)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function ∇broadcast(f, args::Vararg{Any,N}) where N
|
function ∇broadcast(f, args::Vararg{Any,N}) where N
|
||||||
sizes = size.(args)
|
sizes = _size.(args)
|
||||||
dargs = map((x,i) -> dualify(data(x), ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
|
dargs = map((x,i) -> dualify(data(x), ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
|
||||||
out = broadcast(f, dargs...)
|
out = broadcast(f, dargs...)
|
||||||
eltype(out) <: Dual || return out
|
eltype(out) <: Dual || return out
|
||||||
@ -358,14 +361,14 @@ function ∇broadcast(f, args::Vararg{Any,N}) where N
|
|||||||
track(Call(back, tracker.(args)), y)
|
track(Call(back, tracker.(args)), y)
|
||||||
end
|
end
|
||||||
|
|
||||||
Base.Broadcast._containertype(::Type{<:TrackedReal}) = TrackedArray
|
using Base.Broadcast: BroadcastStyle
|
||||||
Base.Broadcast._containertype(::Type{<:TrackedArray}) = TrackedArray
|
|
||||||
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{TrackedArray}) = TrackedArray
|
|
||||||
Base.Broadcast.promote_containertype(::Type{Array}, ::Type{TrackedArray}) = TrackedArray
|
|
||||||
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{Array}) = TrackedArray
|
|
||||||
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ct) = TrackedArray
|
|
||||||
Base.Broadcast.promote_containertype(ct, ::Type{TrackedArray}) = TrackedArray
|
|
||||||
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A::Ref) = ()
|
|
||||||
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A) = indices(A)
|
|
||||||
|
|
||||||
Base.Broadcast.broadcast_c(f, ::Type{TrackedArray}, A, Bs...) = ∇broadcast(f, A, Bs...)
|
struct TrackedStyle <: BroadcastStyle end
|
||||||
|
|
||||||
|
Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray,TrackedReal}}) = TrackedStyle()
|
||||||
|
Broadcast.BroadcastStyle(::TrackedStyle, ::BroadcastStyle) = TrackedStyle()
|
||||||
|
|
||||||
|
function Base.copy(bc::Broadcast.Broadcasted{TrackedStyle})
|
||||||
|
bc = Broadcast.flatten(bc)
|
||||||
|
∇broadcast(bc.f, bc.args...)
|
||||||
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user