From eb41715d26998d2ad711f1644ee0f7127dd01b14 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Tue, 19 Nov 2019 13:30:33 +0530 Subject: [PATCH] define manual rules --- src/utils.jl | 78 ++++++++++++++++++++++++++++++++++------------------ 1 file changed, 51 insertions(+), 27 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index ae2910cc..57e62cca 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -191,54 +191,78 @@ Zeros(sz::Integer...) = Zeros(Bool, sz...) Base.size(xs::Zeros) = xs.size Base.axes(xs::Zeros) = Base.OneTo.(size(xs)) -Base.IndexStyle(::Type{<:Zeros}) = IndexCartesian() +Base.IndexStyle(::Type{<:Zeros}) = IndexLinear() -Base.getindex(xs::Zeros{T,N}, I::Vararg{Int, N}) where {T,N} = zero(T) +Base.getindex(xs::Zeros{T,N}, I::Int) where {T,N} = zero(T) Base.getindex(xs::Zeros{T,N}, inds::Union{Base.OneTo, Base.UnitRange}) where {T,N} = - Zeros(T, inds.stop) + Zeros(T, inds.stop) Base.setindex(xs::Zeros, args...) = - error("setindex disallowed on Zeros Array") + error("setindex disallowed on Zeros Array") Base.setindex!(xs::Zeros, args...) = - error("setindex! disallowed on Zeros Array") + error("setindex! disallowed on Zeros Array") Base.collect(xs::Zeros{T,N}) where {T,N} = fill(zero(T), size(xs)) -# Ignore during backwards pass @adjoint reshape(xs::Zeros{T}, dims...) where T = - reshape(xs, dims...), _ -> nothing + reshape(xs, dims...), _ -> nothing # Define basic ops for f in (:+, :-) - @eval $f(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) = a + @eval function $f(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) + @assert size(a) == size(b) throw(DimensionMismatch("dimensions must match")) + a + end end -Base.:+(a::Zeros, b::AbstractArray) = b -Base.:-(a::Zeros, b::AbstractArray) = -b -Base.:*(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) = zero(a) -Base.:*(a::Zeros, b::AbstractArray) = zero(b) -# Hook into broadcasting API - to allow using as a regular array -Base.BroadcastStyle(::Type{<:Zeros}) = Broadcast.ArrayStyle{Zeros}() -Broadcast.broadcastable(xs::Zeros) = xs -Base.BroadcastStyle(::Broadcast.ArrayStyle{Zeros}, ::Broadcast.DefaultArrayStyle{N}) where N = - Broadcast.ArrayStyle{Zeros}() ++(a::Zeros, b::AbstractArray) = b + a +-(a::Zeros, b::AbstractArray) = -b + a -function Base.similar(bc::Broadcasted{Broadcast.ArrayStyle{Flux.Zeros}}, ::Type{T}) where T - similar(Array{T}, axes(bc)) +function *(a::AbstractArray{S,2}, b::Zeros{T,2}) where {T,S} + @assert size(a,2) == size(b,1) throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))")) + res = similar(a, size(a,1), size(b,2)) + res .= zero(S) +end + +function *(a::Zeros{T,2}, b::AbstractArray{S,2}) where {T,S} + @assert size(a,2) == size(b,1) throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))")) + res = similar(b, size(a,1), size(b,2)) + res .= zero(S) end Base.copy(xs::Zeros{T,N}) where {T,N} = Zeros(T, size(xs)...) -isZeros(x::Zeros) = true -isZeros(x) = false - -function Base.copyto!(dest::AbstractArray, bc::Broadcasted{Broadcast.ArrayStyle{Flux.Zeros}}) - bc = Broadcast.flatten(bc) - - i = isZeros(first(bc.args)) ? 2 : 1 # findfirst(!isZeros, bc.args) - dest .= bc.args[i] +# Define broadcasting behaviour +for op in (:+, :-) + @eval function broadcasted(::typeof($op), a::AbstractArray, b::Zeros) + sz = similar(a, Broadcast.broadcast_shape(size(a), size(b))) + sz .= a + end end +broadcasted(::typeof(+), a::Zeros, b::AbstractArray) = broadcasted(typeof(+), b, a) +broadcasted(::typeof(-), a::Zeros, b::AbstractArray) = broadcasted(typeof(+), -b, a) + +function broadcasted(::typeof(*), a::AbstractArray, b::Zeros) + sz = similar(a, Broadcast.broadcast_shape(size(a), size(b))) + sz .= zero(a) +end + +broadcasted(::typeof(*), a::Zeros, b::AbstractArray) = broadcasted(typeof(*), b, a) + +for op in (:+, :-, :*) + @eval broadcasted(::typeof($op), a::Zeros, b::Zeros) = Zeros(Broadcast.broadcast_shape(size(a), size(b))...) +end + +# Some opportunities to avoid scalar indexing, intermediaries +broadcasted(::typeof(+), a::AbstractArray, b::Zeros{T,0}) where T = a +broadcasted(::typeof(+), a::Zeros{T,0}, b::AbstractArray) where T = b +broadcasted(::typeof(-), a::AbstractArray, b::Zeros{T,0}) where T = a +broadcasted(::typeof(-), a::Zeros{T,0}, b::AbstractArray) where T = -b +broadcasted(::typeof(*), a::AbstractArray, b::Zeros{T,0}) where T = zero(a) +broadcasted(::typeof(*), a::Zeros{T,0}, b::AbstractArray) where T = zero(b) +broadcasted(::typeof(/), a::Zeros{T,0}, b::AbstractArray) where T = zero(b) + """ @jit ...