hook into bcasting
This commit is contained in:
parent
7c90fb469d
commit
a4a987f0b0
66
src/utils.jl
66
src/utils.jl
@ -139,27 +139,45 @@ function throttle(f, timeout; leading=true, trailing=false)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
import Base: +, -, reshape, size
|
import Base: +, -, *, reshape, size
|
||||||
import Base.Broadcast: broadcasted
|
import Base.Broadcast: broadcasted, Broadcasted, BroadcastStyle
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Zeros()
|
Zeros()
|
||||||
|
Zeros(size...)
|
||||||
|
Zeros(Type, size...)
|
||||||
|
|
||||||
Acts as a stand-in for an array of zeros that can be
|
Acts as a stand-in for an array of zeros that can be
|
||||||
used during training which is ignored by the optimisers.
|
used during training which is ignored by the optimisers.
|
||||||
|
|
||||||
Used to turn bias off for a forward pass of a layer.
|
Useful to turn bias off for a forward pass of a layer.
|
||||||
|
|
||||||
|
!!! warning
|
||||||
|
Zeros acts a scalar while broadcasting, so does not
|
||||||
|
expand dims. Checks for shape compatibility by default.
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
```julia
|
```julia
|
||||||
|
julia> Flux.Zeros(3,3)
|
||||||
|
3×3 Flux.Zeros{Bool,2}:
|
||||||
|
false false false
|
||||||
|
false false false
|
||||||
|
false false false
|
||||||
|
|
||||||
|
julia> Flux.Zeros(Float32, 3,3)
|
||||||
|
3×3 Flux.Zeros{Float32,2}:
|
||||||
|
0.0 0.0 0.0
|
||||||
|
0.0 0.0 0.0
|
||||||
|
0.0 0.0 0.0
|
||||||
|
|
||||||
julia> rand(3,3) .+ Flux.Zeros()
|
julia> rand(3,3) .+ Flux.Zeros()
|
||||||
3×3 Array{Float64,2}:
|
3×3 Array{Float64,2}:
|
||||||
0.198739 0.490459 0.785386
|
0.198739 0.490459 0.785386
|
||||||
0.779074 0.39986 0.66383
|
0.779074 0.39986 0.66383
|
||||||
0.854981 0.447292 0.314497
|
0.854981 0.447292 0.314497
|
||||||
|
|
||||||
julia> bias = Conv((2,2), 1=>3, bias = Flux.Zeros())
|
julia> bias_less_conv = Conv((2,2), 1=>3, bias = Flux.Zeros())
|
||||||
Conv((2, 2), 1=>3)
|
Conv((2, 2), 1=>3)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
@ -170,14 +188,15 @@ end
|
|||||||
Zeros(::Type{T}, sz...) where T = Zeros{T,length(sz)}(sz)
|
Zeros(::Type{T}, sz...) where T = Zeros{T,length(sz)}(sz)
|
||||||
Zeros(sz::Integer...) = Zeros(Bool, sz...)
|
Zeros(sz::Integer...) = Zeros(Bool, sz...)
|
||||||
|
|
||||||
+(a::Union{AbstractVecOrMat, Number}, ::Zeros) = a
|
|
||||||
|
|
||||||
Base.size(xs::Zeros) = xs.size
|
Base.size(xs::Zeros) = xs.size
|
||||||
Base.IndexStyle(::Type{<:Zeros}) = IndexLinear()
|
|
||||||
|
|
||||||
Base.axes(xs::Zeros) = Base.OneTo.(size(xs))
|
Base.axes(xs::Zeros) = Base.OneTo.(size(xs))
|
||||||
|
|
||||||
Base.getindex(xs::Zeros{T,N}, i::Int) where {T,N} = zero(T)
|
Base.IndexStyle(::Type{<:Zeros}) = IndexCartesian()
|
||||||
|
|
||||||
|
Base.getindex(xs::Zeros{T,N}, I::Vararg{Int, N}) where {T,N} = zero(T)
|
||||||
|
Base.getindex(xs::Zeros{T,N}, inds::Union{Base.OneTo, Base.UnitRange}) where {T,N} =
|
||||||
|
Zeros(T, inds.stop)
|
||||||
|
|
||||||
Base.setindex(xs::Zeros, args...) =
|
Base.setindex(xs::Zeros, args...) =
|
||||||
error("setindex disallowed on Zeros Array")
|
error("setindex disallowed on Zeros Array")
|
||||||
Base.setindex!(xs::Zeros, args...) =
|
Base.setindex!(xs::Zeros, args...) =
|
||||||
@ -185,17 +204,40 @@ Base.setindex!(xs::Zeros, args...) =
|
|||||||
|
|
||||||
Base.collect(xs::Zeros{T,N}) where {T,N} = fill(zero(T), size(xs))
|
Base.collect(xs::Zeros{T,N}) where {T,N} = fill(zero(T), size(xs))
|
||||||
|
|
||||||
|
# Ignore during backwards pass
|
||||||
@adjoint reshape(xs::Zeros{T}, dims...) where T =
|
@adjoint reshape(xs::Zeros{T}, dims...) where T =
|
||||||
reshape(xs, dims...), _ -> nothing
|
reshape(xs, dims...), _ -> nothing
|
||||||
|
|
||||||
|
# Define basic ops
|
||||||
for f in (:+, :-)
|
for f in (:+, :-)
|
||||||
@eval $f(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) = a
|
@eval $f(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) = a
|
||||||
end
|
end
|
||||||
|
Base.:+(a::Zeros, b::AbstractArray) = b
|
||||||
|
Base.:-(a::Zeros, b::AbstractArray) = -b
|
||||||
Base.:*(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) = zero(a)
|
Base.:*(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros) = zero(a)
|
||||||
|
Base.:*(a::Zeros, b::AbstractArray) = zero(a)
|
||||||
|
|
||||||
broadcasted(::typeof(+), arr::AbstractArray, ::Zeros) = arr
|
# Hook into broadcasting API - to allow using as a regular array
|
||||||
broadcasted(::typeof(-), arr::AbstractArray, ::Zeros) = arr
|
Base.BroadcastStyle(::Type{<:Zeros}) = Broadcast.ArrayStyle{Zeros}()
|
||||||
broadcasted(::typeof(*), arr::AbstractArray, ::Zeros) = zero(arr)
|
Broadcast.broadcastable(xs::Zeros) = xs
|
||||||
|
Base.BroadcastStyle(::Broadcast.ArrayStyle{Zeros}, ::Broadcast.DefaultArrayStyle{N}) where N =
|
||||||
|
Broadcast.ArrayStyle{Zeros}()
|
||||||
|
|
||||||
|
function Base.similar(bc::Broadcasted{Broadcast.ArrayStyle{Flux.Zeros}}, ::Type{T}) where T
|
||||||
|
similar(Array{T}, axes(bc))
|
||||||
|
end
|
||||||
|
|
||||||
|
Base.copy(xs::Zeros{T,N}) where {T,N} = Zeros(T, size(xs)...)
|
||||||
|
|
||||||
|
isZeros(x::Zeros) = true
|
||||||
|
isZeros(x) = false
|
||||||
|
|
||||||
|
function Base.copyto!(dest::AbstractArray, bc::Broadcasted{Broadcast.ArrayStyle{Flux.Zeros}})
|
||||||
|
bc = Broadcast.flatten(bc)
|
||||||
|
|
||||||
|
i = isZeros(first(bc.args)) ? 2 : 1 # findfirst(!isZeros, bc.args)
|
||||||
|
dest .= bc.args[i]
|
||||||
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@jit ...
|
@jit ...
|
||||||
|
Loading…
Reference in New Issue
Block a user