hook into bcasting

This commit is contained in:
Dhairya Gandhi 2019-11-07 16:53:41 +05:30
parent 7c90fb469d
commit a4a987f0b0

View File

@ -139,27 +139,45 @@ function throttle(f, timeout; leading=true, trailing=false)
end end
end end
import Base: +, -, reshape, size import Base: +, -, *, reshape, size
import Base.Broadcast: broadcasted import Base.Broadcast: broadcasted, Broadcasted, BroadcastStyle
""" """
Zeros() Zeros()
Zeros(size...)
Zeros(Type, size...)
Acts as a stand-in for an array of zeros that can be Acts as a stand-in for an array of zeros that can be
used during training which is ignored by the optimisers. used during training which is ignored by the optimisers.
Used to turn bias off for a forward pass of a layer. Useful to turn bias off for a forward pass of a layer.
!!! warning
Zeros acts a scalar while broadcasting, so does not
expand dims. Checks for shape compatibility by default.
## Examples ## Examples
```julia ```julia
julia> Flux.Zeros(3,3)
3×3 Flux.Zeros{Bool,2}:
false false false
false false false
false false false
julia> Flux.Zeros(Float32, 3,3)
3×3 Flux.Zeros{Float32,2}:
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
julia> rand(3,3) .+ Flux.Zeros() julia> rand(3,3) .+ Flux.Zeros()
3×3 Array{Float64,2}: 3×3 Array{Float64,2}:
0.198739 0.490459 0.785386 0.198739 0.490459 0.785386
0.779074 0.39986 0.66383 0.779074 0.39986 0.66383
0.854981 0.447292 0.314497 0.854981 0.447292 0.314497
julia> bias = Conv((2,2), 1=>3, bias = Flux.Zeros()) julia> bias_less_conv = Conv((2,2), 1=>3, bias = Flux.Zeros())
Conv((2, 2), 1=>3) Conv((2, 2), 1=>3)
``` ```
""" """
@ -170,14 +188,15 @@ end
Zeros(::Type{T}, sz...) where T = Zeros{T,length(sz)}(sz) Zeros(::Type{T}, sz...) where T = Zeros{T,length(sz)}(sz)
Zeros(sz::Integer...) = Zeros(Bool, sz...) Zeros(sz::Integer...) = Zeros(Bool, sz...)
+(a::Union{AbstractVecOrMat, Number}, ::Zeros) = a
Base.size(xs::Zeros) = xs.size Base.size(xs::Zeros) = xs.size
Base.IndexStyle(::Type{<:Zeros}) = IndexLinear()
Base.axes(xs::Zeros) = Base.OneTo.(size(xs)) Base.axes(xs::Zeros) = Base.OneTo.(size(xs))
Base.getindex(xs::Zeros{T,N}, i::Int) where {T,N} = zero(T) Base.IndexStyle(::Type{<:Zeros}) = IndexCartesian()
Base.getindex(xs::Zeros{T,N}, I::Vararg{Int, N}) where {T,N} = zero(T)
Base.getindex(xs::Zeros{T,N}, inds::Union{Base.OneTo, Base.UnitRange}) where {T,N} =
Zeros(T, inds.stop)
Base.setindex(xs::Zeros, args...) = Base.setindex(xs::Zeros, args...) =
error("setindex disallowed on Zeros Array") error("setindex disallowed on Zeros Array")
Base.setindex!(xs::Zeros, args...) = Base.setindex!(xs::Zeros, args...) =
@ -185,17 +204,40 @@ Base.setindex!(xs::Zeros, args...) =
Base.collect(xs::Zeros{T,N}) where {T,N} = fill(zero(T), size(xs)) Base.collect(xs::Zeros{T,N}) where {T,N} = fill(zero(T), size(xs))
# Ignore during backwards pass
@adjoint reshape(xs::Zeros{T}, dims...) where T = @adjoint reshape(xs::Zeros{T}, dims...) where T =
reshape(xs, dims...), _ -> nothing reshape(xs, dims...), _ -> nothing
# Define basic ops
for f in (:+, :-) for f in (:+, :-)
@eval $f(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) = a @eval $f(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) = a
end end
Base.:+(a::Zeros, b::AbstractArray) = b
Base.:-(a::Zeros, b::AbstractArray) = -b
Base.:*(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) = zero(a) Base.:*(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) = zero(a)
Base.:*(a::Zeros, b::AbstractArray) = zero(a)
broadcasted(::typeof(+), arr::AbstractArray, ::Zeros) = arr # Hook into broadcasting API - to allow using as a regular array
broadcasted(::typeof(-), arr::AbstractArray, ::Zeros) = arr Base.BroadcastStyle(::Type{<:Zeros}) = Broadcast.ArrayStyle{Zeros}()
broadcasted(::typeof(*), arr::AbstractArray, ::Zeros) = zero(arr) Broadcast.broadcastable(xs::Zeros) = xs
Base.BroadcastStyle(::Broadcast.ArrayStyle{Zeros}, ::Broadcast.DefaultArrayStyle{N}) where N =
Broadcast.ArrayStyle{Zeros}()
function Base.similar(bc::Broadcasted{Broadcast.ArrayStyle{Flux.Zeros}}, ::Type{T}) where T
similar(Array{T}, axes(bc))
end
Base.copy(xs::Zeros{T,N}) where {T,N} = Zeros(T, size(xs)...)
isZeros(x::Zeros) = true
isZeros(x) = false
function Base.copyto!(dest::AbstractArray, bc::Broadcasted{Broadcast.ArrayStyle{Flux.Zeros}})
bc = Broadcast.flatten(bc)
i = isZeros(first(bc.args)) ? 2 : 1 # findfirst(!isZeros, bc.args)
dest .= bc.args[i]
end
""" """
@jit ... @jit ...