make Zeros a dimensionlesss number

This commit is contained in:
Dhairya Gandhi 2019-10-22 16:11:27 +05:30
parent c85bad4427
commit 4a183aeaf0

View File

@ -139,30 +139,40 @@ function throttle(f, timeout; leading=true, trailing=false)
end
end
import Base: +, reshape, size
import Base: +, -, reshape, size
import Base.Broadcast: broadcasted
"""
Zeros()
Zeros(T, a::Union{Colon, Int}...)
Acts as a stand-in for an array of zeros that can be used during training which is
ignored by the optimisers.
Acts as a stand-in for an array of zeros that can be
used during training which is ignored by the optimisers.
Used to turn bias off for a forward pass of a layer.
## Examples
```julia
julia> rand(3,3) .+ Flux.Zeros()
3×3 Array{Float64,2}:
0.198739 0.490459 0.785386
0.779074 0.39986 0.66383
0.854981 0.447292 0.314497
julia> bias = Conv((2,2), 1=>3, bias = Flux.Zeros())
Conv((2, 2), 1=>3)
```
"""
struct Zeros{T} <: Number
size::Tuple
struct Zeros <: Number end
for f in (:+, :-)
@eval $f(a::Union{Number, Zeros}, b::Zeros) = a
end
Base.:*(a::Union{Number, Zeros}, b::Zeros) = zero(a)
Zeros(::Type{T}, sz...) where T = Zeros{T}(sz)
Zeros(sz::Union{Integer, Colon}...) = Zeros(Bool, sz...)
+(a::Number, ::Zeros) = a
+(::Zeros, a::Number) = a
size(xs::Zeros) = xs.size
reshape(z::Zeros{T}, args...) where T = Zeros(T, args...)
@adjoint reshape(xs::Zeros{T}, dims...) where T =
Zeros(T, dims...), Δ -> (Zeros(T, size(xs)...), map(_ -> nothing, dims)...)
broadcasted(::typeof(+), arr::AbstractArray, ::Zeros) = arr
broadcasted(::typeof(*), arr::AbstractArray, ::Zeros) = zero(arr)
Base.reshape(xs::Zeros, args...) = xs
@adjoint reshape(xs::Zeros, dims...) = reshape(xs, dims...), _ -> nothing
"""
@jit ...