diff --git a/src/utils.jl b/src/utils.jl index 155326ab..6e5ab8a2 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -139,27 +139,45 @@ function throttle(f, timeout; leading=true, trailing=false) end end -import Base: +, -, reshape, size -import Base.Broadcast: broadcasted +import Base: +, -, *, reshape, size +import Base.Broadcast: broadcasted, Broadcasted, BroadcastStyle """ Zeros() + Zeros(size...) + Zeros(Type, size...) Acts as a stand-in for an array of zeros that can be used during training which is ignored by the optimisers. -Used to turn bias off for a forward pass of a layer. +Useful to turn bias off for a forward pass of a layer. + +!!! warning + Zeros acts as a scalar while broadcasting, so it does not + expand dims. Checks for shape compatibility by default. ## Examples ```julia +julia> Flux.Zeros(3,3) +3×3 Flux.Zeros{Bool,2}: + false false false + false false false + false false false + +julia> Flux.Zeros(Float32, 3,3) +3×3 Flux.Zeros{Float32,2}: + 0.0 0.0 0.0 + 0.0 0.0 0.0 + 0.0 0.0 0.0 + julia> rand(3,3) .+ Flux.Zeros() 3×3 Array{Float64,2}: 0.198739 0.490459 0.785386 0.779074 0.39986 0.66383 0.854981 0.447292 0.314497 -julia> bias = Conv((2,2), 1=>3, bias = Flux.Zeros()) +julia> bias_less_conv = Conv((2,2), 1=>3, bias = Flux.Zeros()) Conv((2, 2), 1=>3) ``` """ @@ -170,14 +188,15 @@ end Zeros(::Type{T}, sz...) where T = Zeros{T,length(sz)}(sz) Zeros(sz::Integer...) = Zeros(Bool, sz...) -+(a::Union{AbstractVecOrMat, Number}, ::Zeros) = a - Base.size(xs::Zeros) = xs.size -Base.IndexStyle(::Type{<:Zeros}) = IndexLinear() - Base.axes(xs::Zeros) = Base.OneTo.(size(xs)) -Base.getindex(xs::Zeros{T,N}, i::Int) where {T,N} = zero(T) +Base.IndexStyle(::Type{<:Zeros}) = IndexCartesian() + +Base.getindex(xs::Zeros{T,N}, I::Vararg{Int, N}) where {T,N} = zero(T) +Base.getindex(xs::Zeros{T,N}, inds::Union{Base.OneTo, Base.UnitRange}) where {T,N} = + Zeros(T, inds.stop) + Base.setindex(xs::Zeros, args...) 
= error("setindex disallowed on Zeros Array") Base.setindex!(xs::Zeros, args...) = @@ -185,17 +204,40 @@ Base.setindex!(xs::Zeros, args...) = Base.collect(xs::Zeros{T,N}) where {T,N} = fill(zero(T), size(xs)) +# Ignore during backwards pass @adjoint reshape(xs::Zeros{T}, dims...) where T = reshape(xs, dims...), _ -> nothing +# Define basic ops for f in (:+, :-) @eval $f(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) = a end +Base.:+(a::Zeros, b::AbstractArray) = b +Base.:-(a::Zeros, b::AbstractArray) = -b Base.:*(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) = zero(a) +Base.:*(a::Zeros, b::AbstractArray) = zero(a) -broadcasted(::typeof(+), arr::AbstractArray, ::Zeros) = arr -broadcasted(::typeof(-), arr::AbstractArray, ::Zeros) = arr -broadcasted(::typeof(*), arr::AbstractArray, ::Zeros) = zero(arr) +# Hook into broadcasting API - to allow using as a regular array +Base.BroadcastStyle(::Type{<:Zeros}) = Broadcast.ArrayStyle{Zeros}() +Broadcast.broadcastable(xs::Zeros) = xs +Base.BroadcastStyle(::Broadcast.ArrayStyle{Zeros}, ::Broadcast.DefaultArrayStyle{N}) where N = + Broadcast.ArrayStyle{Zeros}() + +function Base.similar(bc::Broadcasted{Broadcast.ArrayStyle{Flux.Zeros}}, ::Type{T}) where T + similar(Array{T}, axes(bc)) +end + +Base.copy(xs::Zeros{T,N}) where {T,N} = Zeros(T, size(xs)...) + +isZeros(x::Zeros) = true +isZeros(x) = false + +function Base.copyto!(dest::AbstractArray, bc::Broadcasted{Broadcast.ArrayStyle{Flux.Zeros}}) + bc = Broadcast.flatten(bc) + + i = isZeros(first(bc.args)) ? 2 : 1 # findfirst(!isZeros, bc.args) + dest .= bc.args[i] +end """ @jit ...