From 7c90fb469d19585d63d95aeb28e68041af7e35b7 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 23 Oct 2019 20:02:15 +0530 Subject: [PATCH] use array to define Zeros --- src/utils.jl | 35 +++++++++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index ee5f2db7..155326ab 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -163,16 +163,39 @@ julia> bias = Conv((2,2), 1=>3, bias = Flux.Zeros()) Conv((2, 2), 1=>3) ``` """ -struct Zeros <: Number end -for f in (:+, :-) - @eval $f(a::Union{Number, Zeros}, b::Zeros) = a +struct Zeros{T,N} <: AbstractArray{T,N} + size::Tuple end -Base.:*(a::Union{Number, Zeros}, b::Zeros) = zero(a) + +Zeros(::Type{T}, sz...) where T = Zeros{T,length(sz)}(sz) +Zeros(sz::Integer...) = Zeros(Bool, sz...) + ++(a::Union{AbstractVecOrMat, Number}, ::Zeros) = a + +Base.size(xs::Zeros) = xs.size +Base.IndexStyle(::Type{<:Zeros}) = IndexLinear() + +Base.axes(xs::Zeros) = Base.OneTo.(size(xs)) + +Base.getindex(xs::Zeros{T,N}, i::Int) where {T,N} = zero(T) +Base.setindex(xs::Zeros, args...) = + error("setindex disallowed on Zeros Array") +Base.setindex!(xs::Zeros, args...) = + error("setindex! disallowed on Zeros Array") + +Base.collect(xs::Zeros{T,N}) where {T,N} = fill(zero(T), size(xs)) + +@adjoint reshape(xs::Zeros{T}, dims...) where T = + reshape(xs, dims...), _ -> nothing + +for f in (:+, :-) + @eval $f(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) = a +end +Base.:*(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) = zero(a) broadcasted(::typeof(+), arr::AbstractArray, ::Zeros) = arr +broadcasted(::typeof(-), arr::AbstractArray, ::Zeros) = arr broadcasted(::typeof(*), arr::AbstractArray, ::Zeros) = zero(arr) -Base.reshape(xs::Zeros, args...) = xs -@adjoint reshape(xs::Zeros, dims...) = reshape(xs, dims...), _ -> nothing """ @jit ...