define manual rules
This commit is contained in:
parent
e89b8eba77
commit
eb41715d26
78
src/utils.jl
78
src/utils.jl
@ -191,54 +191,78 @@ Zeros(sz::Integer...) = Zeros(Bool, sz...)
|
|||||||
Base.size(xs::Zeros) = xs.size
|
Base.size(xs::Zeros) = xs.size
|
||||||
Base.axes(xs::Zeros) = Base.OneTo.(size(xs))
|
Base.axes(xs::Zeros) = Base.OneTo.(size(xs))
|
||||||
|
|
||||||
Base.IndexStyle(::Type{<:Zeros}) = IndexCartesian()
|
Base.IndexStyle(::Type{<:Zeros}) = IndexLinear()
|
||||||
|
|
||||||
Base.getindex(xs::Zeros{T,N}, I::Vararg{Int, N}) where {T,N} = zero(T)
|
Base.getindex(xs::Zeros{T,N}, I::Int) where {T,N} = zero(T)
|
||||||
Base.getindex(xs::Zeros{T,N}, inds::Union{Base.OneTo, Base.UnitRange}) where {T,N} =
|
Base.getindex(xs::Zeros{T,N}, inds::Union{Base.OneTo, Base.UnitRange}) where {T,N} =
|
||||||
Zeros(T, inds.stop)
|
Zeros(T, inds.stop)
|
||||||
|
|
||||||
Base.setindex(xs::Zeros, args...) =
|
Base.setindex(xs::Zeros, args...) =
|
||||||
error("setindex disallowed on Zeros Array")
|
error("setindex disallowed on Zeros Array")
|
||||||
Base.setindex!(xs::Zeros, args...) =
|
Base.setindex!(xs::Zeros, args...) =
|
||||||
error("setindex! disallowed on Zeros Array")
|
error("setindex! disallowed on Zeros Array")
|
||||||
|
|
||||||
Base.collect(xs::Zeros{T,N}) where {T,N} = fill(zero(T), size(xs))
|
Base.collect(xs::Zeros{T,N}) where {T,N} = fill(zero(T), size(xs))
|
||||||
|
|
||||||
# Ignore during backwards pass
|
|
||||||
@adjoint reshape(xs::Zeros{T}, dims...) where T =
|
@adjoint reshape(xs::Zeros{T}, dims...) where T =
|
||||||
reshape(xs, dims...), _ -> nothing
|
reshape(xs, dims...), _ -> nothing
|
||||||
|
|
||||||
# Define basic ops
|
# Define basic ops
|
||||||
for f in (:+, :-)
|
for f in (:+, :-)
|
||||||
@eval $f(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) = a
|
@eval function $f(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros)
|
||||||
|
@assert size(a) == size(b) throw(DimensionMismatch("dimensions must match"))
|
||||||
|
a
|
||||||
|
end
|
||||||
end
|
end
|
||||||
Base.:+(a::Zeros, b::AbstractArray) = b
|
|
||||||
Base.:-(a::Zeros, b::AbstractArray) = -b
|
|
||||||
Base.:*(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) = zero(a)
|
|
||||||
Base.:*(a::Zeros, b::AbstractArray) = zero(b)
|
|
||||||
|
|
||||||
# Hook into broadcasting API - to allow using as a regular array
|
+(a::Zeros, b::AbstractArray) = b + a
|
||||||
Base.BroadcastStyle(::Type{<:Zeros}) = Broadcast.ArrayStyle{Zeros}()
|
-(a::Zeros, b::AbstractArray) = -b + a
|
||||||
Broadcast.broadcastable(xs::Zeros) = xs
|
|
||||||
Base.BroadcastStyle(::Broadcast.ArrayStyle{Zeros}, ::Broadcast.DefaultArrayStyle{N}) where N =
|
|
||||||
Broadcast.ArrayStyle{Zeros}()
|
|
||||||
|
|
||||||
function Base.similar(bc::Broadcasted{Broadcast.ArrayStyle{Flux.Zeros}}, ::Type{T}) where T
|
function *(a::AbstractArray{S,2}, b::Zeros{T,2}) where {T,S}
|
||||||
similar(Array{T}, axes(bc))
|
@assert size(a,2) == size(b,1) throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
|
||||||
|
res = similar(a, size(a,1), size(b,2))
|
||||||
|
res .= zero(S)
|
||||||
|
end
|
||||||
|
|
||||||
|
function *(a::Zeros{T,2}, b::AbstractArray{S,2}) where {T,S}
|
||||||
|
@assert size(a,2) == size(b,1) throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
|
||||||
|
res = similar(b, size(a,1), size(b,2))
|
||||||
|
res .= zero(S)
|
||||||
end
|
end
|
||||||
|
|
||||||
Base.copy(xs::Zeros{T,N}) where {T,N} = Zeros(T, size(xs)...)
|
Base.copy(xs::Zeros{T,N}) where {T,N} = Zeros(T, size(xs)...)
|
||||||
|
|
||||||
isZeros(x::Zeros) = true
|
# Define broadcasting behaviour
|
||||||
isZeros(x) = false
|
for op in (:+, :-)
|
||||||
|
@eval function broadcasted(::typeof($op), a::AbstractArray, b::Zeros)
|
||||||
function Base.copyto!(dest::AbstractArray, bc::Broadcasted{Broadcast.ArrayStyle{Flux.Zeros}})
|
sz = similar(a, Broadcast.broadcast_shape(size(a), size(b)))
|
||||||
bc = Broadcast.flatten(bc)
|
sz .= a
|
||||||
|
end
|
||||||
i = isZeros(first(bc.args)) ? 2 : 1 # findfirst(!isZeros, bc.args)
|
|
||||||
dest .= bc.args[i]
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
broadcasted(::typeof(+), a::Zeros, b::AbstractArray) = broadcasted(typeof(+), b, a)
|
||||||
|
broadcasted(::typeof(-), a::Zeros, b::AbstractArray) = broadcasted(typeof(+), -b, a)
|
||||||
|
|
||||||
|
function broadcasted(::typeof(*), a::AbstractArray, b::Zeros)
|
||||||
|
sz = similar(a, Broadcast.broadcast_shape(size(a), size(b)))
|
||||||
|
sz .= zero(a)
|
||||||
|
end
|
||||||
|
|
||||||
|
broadcasted(::typeof(*), a::Zeros, b::AbstractArray) = broadcasted(typeof(*), b, a)
|
||||||
|
|
||||||
|
for op in (:+, :-, :*)
|
||||||
|
@eval broadcasted(::typeof($op), a::Zeros, b::Zeros) = Zeros(Broadcast.broadcast_shape(size(a), size(b))...)
|
||||||
|
end
|
||||||
|
|
||||||
|
# Some opportunities to avoid scalar indexing, intermediaries
|
||||||
|
broadcasted(::typeof(+), a::AbstractArray, b::Zeros{T,0}) where T = a
|
||||||
|
broadcasted(::typeof(+), a::Zeros{T,0}, b::AbstractArray) where T = b
|
||||||
|
broadcasted(::typeof(-), a::AbstractArray, b::Zeros{T,0}) where T = a
|
||||||
|
broadcasted(::typeof(-), a::Zeros{T,0}, b::AbstractArray) where T = -b
|
||||||
|
broadcasted(::typeof(*), a::AbstractArray, b::Zeros{T,0}) where T = zero(a)
|
||||||
|
broadcasted(::typeof(*), a::Zeros{T,0}, b::AbstractArray) where T = zero(b)
|
||||||
|
broadcasted(::typeof(/), a::Zeros{T,0}, b::AbstractArray) where T = zero(b)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@jit ...
|
@jit ...
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user