diff --git a/src/Flux.jl b/src/Flux.jl index 5799fe42..90dcb630 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -27,6 +27,7 @@ using CuArrays const use_cuda = Ref(false) include("utils.jl") +include("zeros.jl") include("onehot.jl") include("functor.jl") diff --git a/src/utils.jl b/src/utils.jl index c321bd91..7842c961 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -321,110 +321,6 @@ function throttle(f, timeout; leading=true, trailing=false) end end -import Base: +, -, *, reshape, size -import Base.Broadcast: broadcasted, Broadcasted, BroadcastStyle - -""" - Zeros() - Zeros(size...) - Zeros(Type, size...) - -Acts as a stand-in for an array of zeros that can be -used during training which is ignored by the optimisers. - -Useful to turn bias off for a forward pass of a layer. - -## Examples - -```julia -julia> Flux.Zeros(3,3) -3×3 Flux.Zeros{Bool,2}: - false false false - false false false - false false false - -julia> Flux.Zeros(Float32, 3,3) -3×3 Flux.Zeros{Float32,2}: - 0.0 0.0 0.0 - 0.0 0.0 0.0 - 0.0 0.0 0.0 - -julia> rand(3,3) .+ Flux.Zeros() -3×3 Array{Float64,2}: - 0.198739 0.490459 0.785386 - 0.779074 0.39986 0.66383 - 0.854981 0.447292 0.314497 - -julia> bias_less_conv = Conv((2,2), 1=>3, bias = Flux.Zeros()) -Conv((2, 2), 1=>3) -``` -""" -struct Zeros{T,N} <: AbstractArray{T,N} - size::Tuple -end - -Zeros(::Type{T}, sz...) where T = Zeros{T,length(sz)}(sz) -Zeros(sz::Integer...) = Zeros(Bool, sz...) - -Base.size(xs::Zeros) = xs.size -Base.axes(xs::Zeros) = Base.OneTo.(size(xs)) - -Base.IndexStyle(::Type{<:Zeros}) = IndexLinear() - -Base.getindex(xs::Zeros{T,N}, I::Int) where {T,N} = zero(T) -Base.getindex(xs::Zeros{T,N}, inds::Union{Base.OneTo, Base.UnitRange}) where {T,N} = - Zeros(T, inds.stop) - -Base.collect(xs::Zeros{T,N}) where {T,N} = fill(zero(T), size(xs)) - -@adjoint reshape(xs::Zeros{T}, dims...) where T = - reshape(xs, dims...), _ -> nothing - -# Define basic ops -for f in (:+, :-) - @eval @inline function $f(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) - @assert size(a) == size(b) throw(DimensionMismatch("dimensions must match")) - a - end -end - -+(a::Zeros, b::AbstractArray) = b + a --(a::Zeros, b::AbstractArray) = -b + a - -Base.copy(xs::Zeros{T,N}) where {T,N} = xs - -# Define broadcasting behaviour -for op in (:+, :-) - @eval function broadcasted(::typeof($op), a::AbstractArray, b::Zeros) - bs = Broadcast.broadcast_shape(size(a), size(b)) - size(a) == bs && return a - sz = similar(a, bs) - sz .= a - end -end - -broadcasted(::typeof(+), a::Zeros, b::AbstractArray) = broadcasted(+, b, a) -broadcasted(::typeof(-), a::Zeros, b::AbstractArray) = broadcasted(+, -b, a) - -function broadcasted(::typeof(*), a::AbstractArray, b::Zeros) - Zeros(Broadcast.broadcast_shape(size(a), size(b))...) -end - -broadcasted(::typeof(*), a::Zeros, b::AbstractArray) = broadcasted(*, b, a) - -for op in (:+, :-, :*) - @eval broadcasted(::typeof($op), a::Zeros, b::Zeros) = Zeros(Broadcast.broadcast_shape(size(a), size(b))...) -end - -# Some opportunities to avoid scalar indexing, intermediaries -broadcasted(::typeof(+), a::AbstractArray, b::Zeros{T,0}) where T = a -broadcasted(::typeof(+), a::Zeros{T,0}, b::AbstractArray) where T = b -broadcasted(::typeof(-), a::AbstractArray, b::Zeros{T,0}) where T = a -broadcasted(::typeof(-), a::Zeros{T,0}, b::AbstractArray) where T = -b -broadcasted(::typeof(*), a::AbstractArray, b::Zeros{T,0}) where T = zero(a) -broadcasted(::typeof(*), a::Zeros{T,0}, b::AbstractArray) where T = zero(b) -broadcasted(::typeof(/), a::Zeros{T,0}, b::AbstractArray) where T = zero(b) - """ @jit ... diff --git a/src/zeros.jl b/src/zeros.jl new file mode 100644 index 00000000..d281d3eb --- /dev/null +++ b/src/zeros.jl @@ -0,0 +1,103 @@ +import Base: +, -, *, reshape, size +import Base.Broadcast: broadcasted, Broadcasted, BroadcastStyle + +""" + Zeros() + Zeros(size...) + Zeros(Type, size...) + +Acts as a stand-in for an array of zeros that can be +used during training which is ignored by the optimisers. + +Useful to turn bias off for a forward pass of a layer. + +## Examples + +```julia +julia> Flux.Zeros(3,3) +3×3 Flux.Zeros{Bool,2}: + false false false + false false false + false false false + +julia> Flux.Zeros(Float32, 3,3) +3×3 Flux.Zeros{Float32,2}: + 0.0 0.0 0.0 + 0.0 0.0 0.0 + 0.0 0.0 0.0 + +julia> rand(3,3) .+ Flux.Zeros() +3×3 Array{Float64,2}: + 0.198739 0.490459 0.785386 + 0.779074 0.39986 0.66383 + 0.854981 0.447292 0.314497 + +julia> bias_less_conv = Conv((2,2), 1=>3, bias = Flux.Zeros()) +Conv((2, 2), 1=>3) +``` +""" +struct Zeros{T,N} <: AbstractArray{T,N} + size::Tuple +end + +Zeros(::Type{T}, sz...) where T = Zeros{T,length(sz)}(sz) +Zeros(sz::Integer...) = Zeros(Bool, sz...) + +Base.size(xs::Zeros) = xs.size +Base.axes(xs::Zeros) = Base.OneTo.(size(xs)) + +Base.IndexStyle(::Type{<:Zeros}) = IndexLinear() + +Base.getindex(xs::Zeros{T,N}, I::Int) where {T,N} = zero(T) +Base.getindex(xs::Zeros{T,N}, inds::Union{Base.OneTo, Base.UnitRange}) where {T,N} = + Zeros(T, inds.stop) + +Base.collect(xs::Zeros{T,N}) where {T,N} = fill(zero(T), size(xs)) + +@adjoint reshape(xs::Zeros{T}, dims...) where T = + reshape(xs, dims...), _ -> nothing + +# Define basic ops +for f in (:+, :-) + @eval @inline function $f(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) + @assert size(a) == size(b) throw(DimensionMismatch("dimensions must match")) + a + end +end + ++(a::Zeros, b::AbstractArray) = b + a +-(a::Zeros, b::AbstractArray) = -b + a + +Base.copy(xs::Zeros{T,N}) where {T,N} = xs + +# Define broadcasting behaviour +for op in (:+, :-) + @eval function broadcasted(::typeof($op), a::AbstractArray, b::Zeros) + bs = Broadcast.broadcast_shape(size(a), size(b)) + size(a) == bs && return a + sz = similar(a, bs) + sz .= a + end +end + +broadcasted(::typeof(+), a::Zeros, b::AbstractArray) = broadcasted(+, b, a) +broadcasted(::typeof(-), a::Zeros, b::AbstractArray) = broadcasted(+, -b, a) + +function broadcasted(::typeof(*), a::AbstractArray, b::Zeros) + Zeros(Broadcast.broadcast_shape(size(a), size(b))...) +end + +broadcasted(::typeof(*), a::Zeros, b::AbstractArray) = broadcasted(*, b, a) + +for op in (:+, :-, :*) + @eval broadcasted(::typeof($op), a::Zeros, b::Zeros) = Zeros(Broadcast.broadcast_shape(size(a), size(b))...) +end + +# Some opportunities to avoid scalar indexing, intermediaries +broadcasted(::typeof(+), a::AbstractArray, b::Zeros{T,0}) where T = a +broadcasted(::typeof(+), a::Zeros{T,0}, b::AbstractArray) where T = b +broadcasted(::typeof(-), a::AbstractArray, b::Zeros{T,0}) where T = a +broadcasted(::typeof(-), a::Zeros{T,0}, b::AbstractArray) where T = -b +broadcasted(::typeof(*), a::AbstractArray, b::Zeros{T,0}) where T = zero(a) +broadcasted(::typeof(*), a::Zeros{T,0}, b::AbstractArray) where T = zero(b) +broadcasted(::typeof(/), a::Zeros{T,0}, b::AbstractArray) where T = zero(b) \ No newline at end of file