hook into bcasting

This commit is contained in:
Dhairya Gandhi 2019-11-07 16:53:41 +05:30
parent 7c90fb469d
commit a4a987f0b0

View File

@ -139,27 +139,45 @@ function throttle(f, timeout; leading=true, trailing=false)
end
end
import Base: +, -, reshape, size
import Base.Broadcast: broadcasted
import Base: +, -, *, reshape, size
import Base.Broadcast: broadcasted, Broadcasted, BroadcastStyle
"""
Zeros()
Zeros(size...)
Zeros(Type, size...)
Acts as a stand-in for an array of zeros that can be
used during training which is ignored by the optimisers.
Used to turn bias off for a forward pass of a layer.
Useful to turn bias off for a forward pass of a layer.
!!! warning
Zeros acts a scalar while broadcasting, so does not
expand dims. Checks for shape compatibility by default.
## Examples
```julia
julia> Flux.Zeros(3,3)
3×3 Flux.Zeros{Bool,2}:
false false false
false false false
false false false
julia> Flux.Zeros(Float32, 3,3)
3×3 Flux.Zeros{Float32,2}:
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
julia> rand(3,3) .+ Flux.Zeros()
3×3 Array{Float64,2}:
0.198739 0.490459 0.785386
0.779074 0.39986 0.66383
0.854981 0.447292 0.314497
julia> bias = Conv((2,2), 1=>3, bias = Flux.Zeros())
julia> bias_less_conv = Conv((2,2), 1=>3, bias = Flux.Zeros())
Conv((2, 2), 1=>3)
```
"""
@ -170,14 +188,15 @@ end
Zeros(::Type{T}, sz...) where T = Zeros{T,length(sz)}(sz)
Zeros(sz::Integer...) = Zeros(Bool, sz...)
+(a::Union{AbstractVecOrMat, Number}, ::Zeros) = a
Base.size(xs::Zeros) = xs.size
Base.IndexStyle(::Type{<:Zeros}) = IndexLinear()
Base.axes(xs::Zeros) = Base.OneTo.(size(xs))
Base.getindex(xs::Zeros{T,N}, i::Int) where {T,N} = zero(T)
Base.IndexStyle(::Type{<:Zeros}) = IndexCartesian()
Base.getindex(xs::Zeros{T,N}, I::Vararg{Int, N}) where {T,N} = zero(T)
Base.getindex(xs::Zeros{T,N}, inds::Union{Base.OneTo, Base.UnitRange}) where {T,N} =
Zeros(T, inds.stop)
Base.setindex(xs::Zeros, args...) =
error("setindex disallowed on Zeros Array")
Base.setindex!(xs::Zeros, args...) =
@ -185,17 +204,40 @@ Base.setindex!(xs::Zeros, args...) =
Base.collect(xs::Zeros{T,N}) where {T,N} = fill(zero(T), size(xs))
# Ignore during backwards pass
@adjoint reshape(xs::Zeros{T}, dims...) where T =
reshape(xs, dims...), _ -> nothing
# Define basic ops
for f in (:+, :-)
@eval $f(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) = a
end
Base.:+(a::Zeros, b::AbstractArray) = b
Base.:-(a::Zeros, b::AbstractArray) = -b
Base.:*(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) = zero(a)
Base.:*(a::Zeros, b::AbstractArray) = zero(a)
broadcasted(::typeof(+), arr::AbstractArray, ::Zeros) = arr
broadcasted(::typeof(-), arr::AbstractArray, ::Zeros) = arr
broadcasted(::typeof(*), arr::AbstractArray, ::Zeros) = zero(arr)
# Hook into broadcasting API - to allow using as a regular array
Base.BroadcastStyle(::Type{<:Zeros}) = Broadcast.ArrayStyle{Zeros}()
Broadcast.broadcastable(xs::Zeros) = xs
Base.BroadcastStyle(::Broadcast.ArrayStyle{Zeros}, ::Broadcast.DefaultArrayStyle{N}) where N =
Broadcast.ArrayStyle{Zeros}()
function Base.similar(bc::Broadcasted{Broadcast.ArrayStyle{Flux.Zeros}}, ::Type{T}) where T
similar(Array{T}, axes(bc))
end
Base.copy(xs::Zeros{T,N}) where {T,N} = Zeros(T, size(xs)...)
isZeros(x::Zeros) = true
isZeros(x) = false
function Base.copyto!(dest::AbstractArray, bc::Broadcasted{Broadcast.ArrayStyle{Flux.Zeros}})
bc = Broadcast.flatten(bc)
i = isZeros(first(bc.args)) ? 2 : 1 # findfirst(!isZeros, bc.args)
dest .= bc.args[i]
end
"""
@jit ...