correct broadcasting for addition

This commit is contained in:
Dhairya Gandhi 2020-03-04 18:22:45 +05:30
parent 7e308e77fd
commit d8e44fcc1c
1 changed files with 4 additions and 2 deletions

View File

@ -247,7 +247,7 @@ Base.collect(xs::Zeros{T,N}) where {T,N} = fill(zero(T), size(xs))
# Define basic ops # Define basic ops
for f in (:+, :-) for f in (:+, :-)
@eval function $f(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) @eval @inline function $f(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros)
@assert size(a) == size(b) throw(DimensionMismatch("dimensions must match")) @assert size(a) == size(b) throw(DimensionMismatch("dimensions must match"))
a a
end end
@ -261,7 +261,9 @@ Base.copy(xs::Zeros{T,N}) where {T,N} = xs
# Define broadcasting behaviour # Define broadcasting behaviour
for op in (:+, :-) for op in (:+, :-)
@eval function broadcasted(::typeof($op), a::AbstractArray, b::Zeros) @eval function broadcasted(::typeof($op), a::AbstractArray, b::Zeros)
sz = similar(a, Broadcast.broadcast_shape(size(a), size(b))) bs = Broadcast.broadcast_shape(size(a), size(b))
size(a) == bs && return a
sz = similar(a, bs)
sz .= a sz .= a
end end
end end