From 4a183aeaf02a9de9a98f21ee5eddfd0e7f8219f4 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Tue, 22 Oct 2019 16:11:27 +0530 Subject: [PATCH] make Zeros a dimensionlesss number --- src/utils.jl | 44 +++++++++++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 9e095811..ee5f2db7 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -139,30 +139,40 @@ function throttle(f, timeout; leading=true, trailing=false) end end -import Base: +, reshape, size +import Base: +, -, reshape, size +import Base.Broadcast: broadcasted """ Zeros() - Zeros(T, a::Union{Colon, Int}...) -Acts as a stand-in for an array of zeros that can be used during training which is -ignored by the optimisers. +Acts as a stand-in for an array of zeros that can be +used during training which is ignored by the optimisers. + +Used to turn bias off for a forward pass of a layer. + +## Examples + +```julia +julia> rand(3,3) .+ Flux.Zeros() +3×3 Array{Float64,2}: + 0.198739 0.490459 0.785386 + 0.779074 0.39986 0.66383 + 0.854981 0.447292 0.314497 + +julia> bias = Conv((2,2), 1=>3, bias = Flux.Zeros()) +Conv((2, 2), 1=>3) +``` """ -struct Zeros{T} <: Number - size::Tuple +struct Zeros <: Number end +for f in (:+, :-) + @eval $f(a::Union{Number, Zeros}, b::Zeros) = a end +Base.:*(a::Union{Number, Zeros}, b::Zeros) = zero(a) -Zeros(::Type{T}, sz...) where T = Zeros{T}(sz) -Zeros(sz::Union{Integer, Colon}...) = Zeros(Bool, sz...) - -+(a::Number, ::Zeros) = a -+(::Zeros, a::Number) = a - -size(xs::Zeros) = xs.size -reshape(z::Zeros{T}, args...) where T = Zeros(T, args...) - -@adjoint reshape(xs::Zeros{T}, dims...) where T = - Zeros(T, dims...), Δ -> (Zeros(T, size(xs)...), map(_ -> nothing, dims)...) +broadcasted(::typeof(+), arr::AbstractArray, ::Zeros) = arr +broadcasted(::typeof(*), arr::AbstractArray, ::Zeros) = zero(arr) +Base.reshape(xs::Zeros, args...) = xs +@adjoint reshape(xs::Zeros, dims...) = reshape(xs, dims...), _ -> nothing """ @jit ...