hook into bcasting
This commit is contained in:
parent
7c90fb469d
commit
a4a987f0b0
66
src/utils.jl
66
src/utils.jl
@ -139,27 +139,45 @@ function throttle(f, timeout; leading=true, trailing=false)
|
||||
end
|
||||
end
|
||||
|
||||
import Base: +, -, reshape, size
|
||||
import Base.Broadcast: broadcasted
|
||||
import Base: +, -, *, reshape, size
|
||||
import Base.Broadcast: broadcasted, Broadcasted, BroadcastStyle
|
||||
|
||||
"""
|
||||
Zeros()
|
||||
Zeros(size...)
|
||||
Zeros(Type, size...)
|
||||
|
||||
Acts as a stand-in for an array of zeros that can be
|
||||
used during training which is ignored by the optimisers.
|
||||
|
||||
Used to turn bias off for a forward pass of a layer.
|
||||
Useful to turn bias off for a forward pass of a layer.
|
||||
|
||||
!!! warning
|
||||
Zeros acts a scalar while broadcasting, so does not
|
||||
expand dims. Checks for shape compatibility by default.
|
||||
|
||||
## Examples
|
||||
|
||||
```julia
|
||||
julia> Flux.Zeros(3,3)
|
||||
3×3 Flux.Zeros{Bool,2}:
|
||||
false false false
|
||||
false false false
|
||||
false false false
|
||||
|
||||
julia> Flux.Zeros(Float32, 3,3)
|
||||
3×3 Flux.Zeros{Float32,2}:
|
||||
0.0 0.0 0.0
|
||||
0.0 0.0 0.0
|
||||
0.0 0.0 0.0
|
||||
|
||||
julia> rand(3,3) .+ Flux.Zeros()
|
||||
3×3 Array{Float64,2}:
|
||||
0.198739 0.490459 0.785386
|
||||
0.779074 0.39986 0.66383
|
||||
0.854981 0.447292 0.314497
|
||||
|
||||
julia> bias = Conv((2,2), 1=>3, bias = Flux.Zeros())
|
||||
julia> bias_less_conv = Conv((2,2), 1=>3, bias = Flux.Zeros())
|
||||
Conv((2, 2), 1=>3)
|
||||
```
|
||||
"""
|
||||
@ -170,14 +188,15 @@ end
|
||||
Zeros(::Type{T}, sz...) where T = Zeros{T,length(sz)}(sz)
|
||||
Zeros(sz::Integer...) = Zeros(Bool, sz...)
|
||||
|
||||
+(a::Union{AbstractVecOrMat, Number}, ::Zeros) = a
|
||||
|
||||
Base.size(xs::Zeros) = xs.size
|
||||
Base.IndexStyle(::Type{<:Zeros}) = IndexLinear()
|
||||
|
||||
Base.axes(xs::Zeros) = Base.OneTo.(size(xs))
|
||||
|
||||
Base.getindex(xs::Zeros{T,N}, i::Int) where {T,N} = zero(T)
|
||||
Base.IndexStyle(::Type{<:Zeros}) = IndexCartesian()
|
||||
|
||||
Base.getindex(xs::Zeros{T,N}, I::Vararg{Int, N}) where {T,N} = zero(T)
|
||||
Base.getindex(xs::Zeros{T,N}, inds::Union{Base.OneTo, Base.UnitRange}) where {T,N} =
|
||||
Zeros(T, inds.stop)
|
||||
|
||||
Base.setindex(xs::Zeros, args...) =
|
||||
error("setindex disallowed on Zeros Array")
|
||||
Base.setindex!(xs::Zeros, args...) =
|
||||
@ -185,17 +204,40 @@ Base.setindex!(xs::Zeros, args...) =
|
||||
|
||||
Base.collect(xs::Zeros{T,N}) where {T,N} = fill(zero(T), size(xs))
|
||||
|
||||
# Ignore during backwards pass
|
||||
@adjoint reshape(xs::Zeros{T}, dims...) where T =
|
||||
reshape(xs, dims...), _ -> nothing
|
||||
|
||||
# Define basic ops
|
||||
for f in (:+, :-)
|
||||
@eval $f(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) = a
|
||||
end
|
||||
Base.:+(a::Zeros, b::AbstractArray) = b
|
||||
Base.:-(a::Zeros, b::AbstractArray) = -b
|
||||
Base.:*(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) = zero(a)
|
||||
Base.:*(a::Zeros, b::AbstractArray) = zero(a)
|
||||
|
||||
broadcasted(::typeof(+), arr::AbstractArray, ::Zeros) = arr
|
||||
broadcasted(::typeof(-), arr::AbstractArray, ::Zeros) = arr
|
||||
broadcasted(::typeof(*), arr::AbstractArray, ::Zeros) = zero(arr)
|
||||
# Hook into broadcasting API - to allow using as a regular array
|
||||
Base.BroadcastStyle(::Type{<:Zeros}) = Broadcast.ArrayStyle{Zeros}()
|
||||
Broadcast.broadcastable(xs::Zeros) = xs
|
||||
Base.BroadcastStyle(::Broadcast.ArrayStyle{Zeros}, ::Broadcast.DefaultArrayStyle{N}) where N =
|
||||
Broadcast.ArrayStyle{Zeros}()
|
||||
|
||||
function Base.similar(bc::Broadcasted{Broadcast.ArrayStyle{Flux.Zeros}}, ::Type{T}) where T
|
||||
similar(Array{T}, axes(bc))
|
||||
end
|
||||
|
||||
Base.copy(xs::Zeros{T,N}) where {T,N} = Zeros(T, size(xs)...)
|
||||
|
||||
isZeros(x::Zeros) = true
|
||||
isZeros(x) = false
|
||||
|
||||
function Base.copyto!(dest::AbstractArray, bc::Broadcasted{Broadcast.ArrayStyle{Flux.Zeros}})
|
||||
bc = Broadcast.flatten(bc)
|
||||
|
||||
i = isZeros(first(bc.args)) ? 2 : 1 # findfirst(!isZeros, bc.args)
|
||||
dest .= bc.args[i]
|
||||
end
|
||||
|
||||
"""
|
||||
@jit ...
|
||||
|
Loading…
Reference in New Issue
Block a user