use array to define Zeros
This commit is contained in:
parent
4a183aeaf0
commit
7c90fb469d
35
src/utils.jl
35
src/utils.jl
@ -163,16 +163,39 @@ julia> bias = Conv((2,2), 1=>3, bias = Flux.Zeros())
|
||||
Conv((2, 2), 1=>3)
|
||||
```
|
||||
"""
|
||||
struct Zeros <: Number end
|
||||
for f in (:+, :-)
|
||||
@eval $f(a::Union{Number, Zeros}, b::Zeros) = a
|
||||
struct Zeros{T,N} <: AbstractArray{T,N}
|
||||
size::Tuple
|
||||
end
|
||||
Base.:*(a::Union{Number, Zeros}, b::Zeros) = zero(a)
|
||||
|
||||
Zeros(::Type{T}, sz...) where T = Zeros{T,length(sz)}(sz)
|
||||
Zeros(sz::Integer...) = Zeros(Bool, sz...)
|
||||
|
||||
+(a::Union{AbstractVecOrMat, Number}, ::Zeros) = a
|
||||
|
||||
Base.size(xs::Zeros) = xs.size
|
||||
Base.IndexStyle(::Type{<:Zeros}) = IndexLinear()
|
||||
|
||||
Base.axes(xs::Zeros) = Base.OneTo.(size(xs))
|
||||
|
||||
Base.getindex(xs::Zeros{T,N}, i::Int) where {T,N} = zero(T)
|
||||
Base.setindex(xs::Zeros, args...) =
|
||||
error("setindex disallowed on Zeros Array")
|
||||
Base.setindex!(xs::Zeros, args...) =
|
||||
error("setindex! disallowed on Zeros Array")
|
||||
|
||||
Base.collect(xs::Zeros{T,N}) where {T,N} = fill(zero(T), size(xs))
|
||||
|
||||
@adjoint reshape(xs::Zeros{T}, dims...) where T =
|
||||
reshape(xs, dims...), _ -> nothing
|
||||
|
||||
for f in (:+, :-)
|
||||
@eval $f(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) = a
|
||||
end
|
||||
Base.:*(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) = zero(a)
|
||||
|
||||
broadcasted(::typeof(+), arr::AbstractArray, ::Zeros) = arr
|
||||
broadcasted(::typeof(-), arr::AbstractArray, ::Zeros) = arr
|
||||
broadcasted(::typeof(*), arr::AbstractArray, ::Zeros) = zero(arr)
|
||||
Base.reshape(xs::Zeros, args...) = xs
|
||||
@adjoint reshape(xs::Zeros, dims...) = reshape(xs, dims...), _ -> nothing
|
||||
|
||||
"""
|
||||
@jit ...
|
||||
|
Loading…
Reference in New Issue
Block a user