use array to define Zeros

This commit is contained in:
Dhairya Gandhi 2019-10-23 20:02:15 +05:30
parent 4a183aeaf0
commit 7c90fb469d

View File

@ -163,16 +163,39 @@ julia> bias = Conv((2,2), 1=>3, bias = Flux.Zeros())
Conv((2, 2), 1=>3) Conv((2, 2), 1=>3)
``` ```
""" """
struct Zeros <: Number end struct Zeros{T,N} <: AbstractArray{T,N}
for f in (:+, :-) size::Tuple
@eval $f(a::Union{Number, Zeros}, b::Zeros) = a
end end
Base.:*(a::Union{Number, Zeros}, b::Zeros) = zero(a)
Zeros(::Type{T}, sz...) where T = Zeros{T,length(sz)}(sz)
Zeros(sz::Integer...) = Zeros(Bool, sz...)
+(a::Union{AbstractVecOrMat, Number}, ::Zeros) = a
Base.size(xs::Zeros) = xs.size
Base.IndexStyle(::Type{<:Zeros}) = IndexLinear()
Base.axes(xs::Zeros) = Base.OneTo.(size(xs))
Base.getindex(xs::Zeros{T,N}, i::Int) where {T,N} = zero(T)
Base.setindex(xs::Zeros, args...) =
error("setindex disallowed on Zeros Array")
Base.setindex!(xs::Zeros, args...) =
error("setindex! disallowed on Zeros Array")
Base.collect(xs::Zeros{T,N}) where {T,N} = fill(zero(T), size(xs))
@adjoint reshape(xs::Zeros{T}, dims...) where T =
reshape(xs, dims...), _ -> nothing
for f in (:+, :-)
@eval $f(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) = a
end
Base.:*(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) = zero(a)
broadcasted(::typeof(+), arr::AbstractArray, ::Zeros) = arr broadcasted(::typeof(+), arr::AbstractArray, ::Zeros) = arr
broadcasted(::typeof(-), arr::AbstractArray, ::Zeros) = arr
broadcasted(::typeof(*), arr::AbstractArray, ::Zeros) = zero(arr) broadcasted(::typeof(*), arr::AbstractArray, ::Zeros) = zero(arr)
Base.reshape(xs::Zeros, args...) = xs
@adjoint reshape(xs::Zeros, dims...) = reshape(xs, dims...), _ -> nothing
""" """
@jit ... @jit ...