make Zeros a dimensionlesss number
This commit is contained in:
parent
c85bad4427
commit
4a183aeaf0
44
src/utils.jl
44
src/utils.jl
@ -139,30 +139,40 @@ function throttle(f, timeout; leading=true, trailing=false)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
import Base: +, reshape, size
|
import Base: +, -, reshape, size
|
||||||
|
import Base.Broadcast: broadcasted
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Zeros()
|
Zeros()
|
||||||
Zeros(T, a::Union{Colon, Int}...)
|
|
||||||
|
|
||||||
Acts as a stand-in for an array of zeros that can be used during training which is
|
Acts as a stand-in for an array of zeros that can be
|
||||||
ignored by the optimisers.
|
used during training which is ignored by the optimisers.
|
||||||
|
|
||||||
|
Used to turn bias off for a forward pass of a layer.
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
```julia
|
||||||
|
julia> rand(3,3) .+ Flux.Zeros()
|
||||||
|
3×3 Array{Float64,2}:
|
||||||
|
0.198739 0.490459 0.785386
|
||||||
|
0.779074 0.39986 0.66383
|
||||||
|
0.854981 0.447292 0.314497
|
||||||
|
|
||||||
|
julia> bias = Conv((2,2), 1=>3, bias = Flux.Zeros())
|
||||||
|
Conv((2, 2), 1=>3)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
struct Zeros{T} <: Number
|
struct Zeros <: Number end
|
||||||
size::Tuple
|
for f in (:+, :-)
|
||||||
|
@eval $f(a::Union{Number, Zeros}, b::Zeros) = a
|
||||||
end
|
end
|
||||||
|
Base.:*(a::Union{Number, Zeros}, b::Zeros) = zero(a)
|
||||||
|
|
||||||
Zeros(::Type{T}, sz...) where T = Zeros{T}(sz)
|
broadcasted(::typeof(+), arr::AbstractArray, ::Zeros) = arr
|
||||||
Zeros(sz::Union{Integer, Colon}...) = Zeros(Bool, sz...)
|
broadcasted(::typeof(*), arr::AbstractArray, ::Zeros) = zero(arr)
|
||||||
|
Base.reshape(xs::Zeros, args...) = xs
|
||||||
+(a::Number, ::Zeros) = a
|
@adjoint reshape(xs::Zeros, dims...) = reshape(xs, dims...), _ -> nothing
|
||||||
+(::Zeros, a::Number) = a
|
|
||||||
|
|
||||||
size(xs::Zeros) = xs.size
|
|
||||||
reshape(z::Zeros{T}, args...) where T = Zeros(T, args...)
|
|
||||||
|
|
||||||
@adjoint reshape(xs::Zeros{T}, dims...) where T =
|
|
||||||
Zeros(T, dims...), Δ -> (Zeros(T, size(xs)...), map(_ -> nothing, dims)...)
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@jit ...
|
@jit ...
|
||||||
|
Loading…
Reference in New Issue
Block a user