873: Make bias optional r=MikeInnes a=dhairyagandhi96

Addresses #868 



Co-authored-by: Dhairya Gandhi <dhairya@juliacopmuting.com>
Commit 5d9acc7e73 by bors[bot], 2020-05-01 13:28:15 +00:00, committed via GitHub.
4 changed files with 319 additions and 37 deletions
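In short: every convolutional layer now accepts `weight` and `bias` keywords, and passing `Flux.Zeros()` as the bias removes it from the forward pass and from optimisation. A minimal usage sketch (array sizes are illustrative, not from the diff):

```julia
using Flux

# A bias-less convolution: Flux.Zeros() stands in for the bias and is ignored by the optimisers.
m = Conv((2, 2), 1 => 3, relu, bias = Flux.Zeros())

x = rand(Float32, 28, 28, 1, 1)
size(m(x))  # (27, 27, 3, 1)
```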


@ -27,6 +27,7 @@ using CuArrays
const use_cuda = Ref(false)
include("utils.jl")
include("zeros.jl")
include("onehot.jl")
include("functor.jl")


@ -30,27 +30,36 @@ function calc_padding(::SamePad, k::NTuple{N,T}, dilation, stride) where {N,T}
end
"""
Conv(size, in => out, σ = identity; init = glorot_uniform,
Conv(filter, in => out, σ = identity; init = glorot_uniform,
stride = 1, pad = 0, dilation = 1)
Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
filter = (2,2)
in = 1
out = 16
Conv((2, 2), 1=>16, relu)
Standard convolutional layer. `filter` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.
Data should be stored in WHCN order (width, height, # channels, batch size).
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
and a batch of 50 would be a `100×100×3×50` array.
Accepts keyword arguments `weight` and `bias` to set the corresponding fields.
Setting `bias` to `Flux.Zeros()` will switch bias off for the layer.
Takes the keyword arguments `pad`, `stride` and `dilation`.
Use `pad=SamePad()` to apply padding so that outputsize == inputsize / stride.
# Examples
Apply a `Conv` layer to a 1-channel input using a 2×2 window size, giving us a
Apply a `Conv` layer to a 1-channel input using a 2×2 window filter size, giving us a
16-channel output. Output is activated with ReLU.
```julia
size = (2,2)
filter = (2,2)
in = 1
out = 16
Conv(size, in => out, relu)
Conv(filter, in => out, relu)
```
"""
struct Conv{N,M,F,A,V}
@ -62,7 +71,28 @@ struct Conv{N,M,F,A,V}
dilation::NTuple{N,Int}
end
function Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
"""
Conv(weight::AbstractArray, bias::AbstractArray)
Conv(weight::AbstractArray, bias::AbstractArray, activation)
Constructs the convolutional layer with user defined weight and bias arrays.
Setting `bias` to `Flux.Zeros()` would switch `bias` off for the layer.
Takes the keyword arguments `pad`, `stride` and `dilation`.
There is also a keyword-only constructor available for all convolutional
layers.
```julia
weight = rand(Float32, 3, 3, 5)
bias = zeros(Float32, 5)
Conv(weight = weight,
     bias = bias,
     activation = sigmoid)
```
"""
function Conv(w::AbstractArray{T,N}, b::Union{Zeros, AbstractVector{T}}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
stride = expand(Val(N-2), stride)
dilation = expand(Val(N-2), dilation)
@ -70,10 +100,32 @@ function Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
return Conv(σ, w, b, stride, pad, dilation)
end
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
Conv(init(k..., ch...), zeros(ch[2]), σ,
stride = stride, pad = pad, dilation = dilation)
function Conv(;weight::AbstractArray{T,N}, bias::Union{Zeros, AbstractVector{T}},
activation = identity, stride = 1, pad = 0, dilation = 1) where {T,N}
Conv(weight, bias, activation, stride = stride, pad = pad, dilation = dilation)
end
"""
convfilter(filter::Tuple, in=>out)
Constructs a standard convolutional weight matrix with given `filter` and
channels from `in` to `out`.
Accepts the keyword `init` (default: `glorot_uniform`) to control the sampling
distribution.
See also: [`depthwiseconvfilter`](@ref)
"""
convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
init = glorot_uniform) where N = init(filter..., ch...)
function Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
weight = convfilter(k, ch, init = init), bias = zeros(ch[2])) where N
Conv(weight, bias, σ,
stride = stride, pad = pad, dilation = dilation)
end
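As a sketch of the two construction paths added above (the 5×5 filter and 3=>7 channels are illustrative):

```julia
# Positional form, overriding the bias via the new keyword.
c1 = Conv((5, 5), 3 => 7, relu, bias = Flux.Zeros())

# Keyword-only form, building the weight explicitly with convfilter.
w  = Flux.convfilter((5, 5), 3 => 7)               # 5×5×3×7 Float32 array
c2 = Conv(weight = w, bias = zeros(Float32, 7), activation = relu)
```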
@functor Conv
@ -114,16 +166,22 @@ outdims(l::Conv, isize) =
output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))
"""
ConvTranspose(size, in => out, σ = identity; init = glorot_uniform,
ConvTranspose(filter, in=>out)
ConvTranspose(filter, in=>out, activation)
ConvTranspose(filter, in => out, σ = identity; init = glorot_uniform,
stride = 1, pad = 0, dilation = 1)
Standard convolutional transpose layer. `size` should be a tuple like `(2, 2)`.
Standard convolutional transpose layer. `filter` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.
Data should be stored in WHCN order (width, height, # channels, batch size).
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
and a batch of 50 would be a `100×100×3×50` array.
Accepts keyword arguments `weight` and `bias` to set the corresponding fields.
Setting `bias` to `Flux.Zeros()` will switch bias off for the layer.
Takes the keyword arguments `pad`, `stride` and `dilation`.
Use `pad=SamePad()` to apply padding so that outputsize == stride * inputsize - stride + 1.
"""
struct ConvTranspose{N,M,F,A,V}
@ -135,18 +193,39 @@ struct ConvTranspose{N,M,F,A,V}
dilation::NTuple{N,Int}
end
function ConvTranspose(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
"""
ConvTranspose(weight::AbstractArray, bias::AbstractArray)
ConvTranspose(weight::AbstractArray, bias::AbstractArray, activation)
Constructs the convolutional transpose layer with user defined weight and bias arrays
used for the forward pass.
Setting `bias` to `Flux.Zeros()` would switch `bias` off for the layer.
Takes the keyword arguments `pad`, `stride` and `dilation`.
For the keyword-only constructor, see also [`Conv`](@ref)
"""
function ConvTranspose(w::AbstractArray{T,N}, b::Union{Zeros, AbstractVector{T}}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
stride = expand(Val(N-2), stride)
dilation = expand(Val(N-2), dilation)
pad = calc_padding(pad, size(w)[1:N-2], dilation, stride)
return ConvTranspose(σ, w, b, stride, pad, dilation)
end
ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
ConvTranspose(init(k..., reverse(ch)...), zeros(ch[2]), σ,
function ConvTranspose(;weight::AbstractArray{T,N}, bias::Union{Zeros, AbstractVector{T}},
activation = identity, stride = 1, pad = 0, dilation = 1) where {T,N}
ConvTranspose(weight, bias, activation, stride = stride, pad = pad, dilation = dilation)
end
function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
weight = convfilter(k, reverse(ch), init = init), bias = zeros(ch[2])) where N
ConvTranspose(weight, bias, σ,
stride = stride, pad = pad, dilation = dilation)
end
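A matching sketch for `ConvTranspose` (channel counts illustrative); note that the weight for the `in => out` form is built with the channels reversed, as in the constructor above:

```julia
# Standard form with the bias switched off.
ct = ConvTranspose((3, 3), 16 => 8, relu, bias = Flux.Zeros())

# Keyword-only form with an explicit weight array (reversed channels, so 3×3×8×16).
w   = Flux.convfilter((3, 3), reverse(16 => 8))
ct2 = ConvTranspose(weight = w, bias = zeros(Float32, 8), activation = relu)
```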
@functor ConvTranspose
@ -158,9 +237,9 @@ function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
batch_size = size(x)[end]
# Create DenseConvDims() that looks like the corresponding conv()
return DenseConvDims((I..., C_in, batch_size), size(c.weight);
stride=c.stride,
padding=c.pad,
dilation=c.dilation,
stride=c.stride,
padding=c.pad,
dilation=c.dilation,
)
end
@ -171,7 +250,7 @@ function (c::ConvTranspose)(x::AbstractArray)
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
cdims = conv_transpose_dims(c, x)
return σ.(∇conv_data(x, c.weight, cdims) .+ b)
σ.(∇conv_data(x, c.weight, cdims) .+ b)
end
function Base.show(io::IO, l::ConvTranspose)
@ -190,10 +269,12 @@ end
outdims(l::ConvTranspose{N}, isize) where N = _convtransoutdims(isize[1:2], size(l.weight)[1:N], l.stride, l.dilation, l.pad)
"""
DepthwiseConv(size, in => out, σ = identity; init = glorot_uniform,
DepthwiseConv(filter::Tuple, in=>out)
DepthwiseConv(filter::Tuple, in=>out, activation)
DepthwiseConv(filter, in => out, σ = identity; init = glorot_uniform,
stride = 1, pad = 0, dilation = 1)
Depthwise convolutional layer. `size` should be a tuple like `(2, 2)`.
Depthwise convolutional layer. `filter` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.
Note that `out` must be an integer multiple of `in`.
@ -201,6 +282,10 @@ Data should be stored in WHCN order (width, height, # channels, batch size).
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
and a batch of 50 would be a `100×100×3×50` array.
Accepts keyword arguments `weight` and `bias` to set the corresponding fields.
Setting `bias` to `Flux.Zeros()` will switch bias off for the layer.
Takes the keyword arguments `pad`, `stride` and `dilation`.
Use `pad=SamePad()` to apply padding so that outputsize == inputsize / stride.
"""
struct DepthwiseConv{N,M,F,A,V}
@ -212,20 +297,54 @@ struct DepthwiseConv{N,M,F,A,V}
dilation::NTuple{N,Int}
end
function DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
"""
DepthwiseConv(weight::AbstractArray, bias::AbstractArray)
DepthwiseConv(weight::AbstractArray, bias::AbstractArray, activation)
Constructs the `DepthwiseConv` layer with user defined weight and bias arrays
used for the forward pass.
Setting `bias` to `Flux.Zeros()` would switch `bias` off for the layer.
Takes the keyword arguments `pad`, `stride` and `dilation`.
For the keyword-only constructor, see also [`Conv`](@ref)
"""
function DepthwiseConv(w::AbstractArray{T,N}, b::Union{Zeros, AbstractVector{T}}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
stride = expand(Val(N-2), stride)
dilation = expand(Val(N-2), dilation)
pad = calc_padding(pad, size(w)[1:N-2], dilation, stride)
return DepthwiseConv(σ, w, b, stride, pad, dilation)
end
function DepthwiseConv(;weight::AbstractArray{T,N}, bias::Union{Zeros, AbstractVector{T}},
activation = identity, stride = 1, pad = 0, dilation = 1) where {T,N}
DepthwiseConv(weight, bias, activation, stride = stride, pad = pad, dilation = dilation)
end
"""
depthwiseconvfilter(filter::Tuple, in=>out)
Constructs a depthwise convolutional weight array defined by `filter` and channels
from `in` to `out`.
Accepts the keyword `init` (default: `glorot_uniform`) to control the sampling
distribution.
See also: [`convfilter`](@ref)
"""
depthwiseconvfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
init = glorot_uniform) where N = init(filter..., div(ch[2], ch[1]), ch[1])
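A short sketch of the shape `depthwiseconvfilter` produces (channel counts illustrative):

```julia
# For in=>out = 3=>6 the channel multiplier is div(6, 3) = 2,
# so the weight has shape (filter..., multiplier, in) = (2, 2, 2, 3).
w = Flux.depthwiseconvfilter((2, 2), 3 => 6)
size(w)  # (2, 2, 2, 3)

d = DepthwiseConv((2, 2), 3 => 6, relu, bias = Flux.Zeros())
```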
function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
weight = depthwiseconvfilter(k, ch, init = init), bias = zeros(ch[2])) where N
@assert ch[2] % ch[1] == 0 "Output channels must be integer multiple of input channels"
return DepthwiseConv(
init(k..., div(ch[2], ch[1]), ch[1]),
zeros(ch[2]),
weight,
bias,
σ;
stride = stride,
pad = pad,
@ -258,24 +377,30 @@ outdims(l::DepthwiseConv, isize) =
output_size(DepthwiseConvDims(_paddims(isize, (1, 1, size(l.weight)[end], 1)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))
"""
CrossCor(size, in => out, σ = identity; init = glorot_uniform,
CrossCor(filter, in=>out)
CrossCor(filter, in=>out, activation)
CrossCor(filter, in => out, σ = identity; init = glorot_uniform,
stride = 1, pad = 0, dilation = 1)
Standard cross convolutional layer. `size` should be a tuple like `(2, 2)`.
Standard cross convolutional layer. `filter` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.
Data should be stored in WHCN order (width, height, # channels, batch size).
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
and a batch of 50 would be a `100×100×3×50` array.
Accepts keyword arguments `weight` and `bias` to set the corresponding fields.
Setting `bias` to `Flux.Zeros()` will switch bias off for the layer.
Takes the keyword arguments `pad`, `stride` and `dilation`.
Use `pad=SamePad()` to apply padding so that outputsize == inputsize / stride.
# Examples
Apply a `CrossCor` layer to a 1-channel input using a 2×2 window size, giving us a
Apply a `CrossCor` layer to a 1-channel input using a 2×2 window filter size, giving us a
16-channel output. Output is activated with ReLU.
```julia
size = (2,2)
filter = (2,2)
in = 1
out = 16
CrossCor((2, 2), 1=>16, relu)
@ -290,18 +415,39 @@ struct CrossCor{N,M,F,A,V}
dilation::NTuple{N,Int}
end
function CrossCor(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
"""
CrossCor(weight::AbstractArray, bias::AbstractArray)
CrossCor(weight::AbstractArray, bias::AbstractArray, activation)
Constructs the standard cross convolutional layer with user defined weight and bias
arrays.
Setting `bias` to `Flux.Zeros()` would switch `bias` off for the layer.
Takes the keyword arguments `pad`, `stride` and `dilation`.
For the keyword-only constructor, see also [`Conv`](@ref)
"""
function CrossCor(w::AbstractArray{T,N}, b::Union{Zeros, AbstractVector{T}}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
stride = expand(Val(N-2), stride)
dilation = expand(Val(N-2), dilation)
pad = calc_padding(pad, size(w)[1:N-2], dilation, stride)
return CrossCor(σ, w, b, stride, pad, dilation)
end
CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
CrossCor(init(k..., ch...), zeros(ch[2]), σ,
function CrossCor(;weight::AbstractArray{T,N}, bias::Union{Zeros, AbstractVector{T}},
activation = identity, stride = 1, pad = 0, dilation = 1) where {T,N}
CrossCor(weight, bias, activation, stride = stride, pad = pad, dilation = dilation)
end
function CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
weight = convfilter(k, ch, init = init), bias = zeros(ch[2])) where N
CrossCor(weight, bias, σ,
stride = stride, pad = pad, dilation = dilation)
end
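`CrossCor` follows the same pattern; a minimal sketch (channel counts illustrative):

```julia
# Cross-correlation layer without bias; the weight comes from the shared convfilter helper.
cc = CrossCor((2, 2), 1 => 16, relu, bias = Flux.Zeros())

# Equivalent keyword-only construction.
w   = Flux.convfilter((2, 2), 1 => 16)
cc2 = CrossCor(weight = w, bias = Flux.Zeros(), activation = relu)
```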
@functor CrossCor

src/zeros.jl (new file, 106 lines)

@ -0,0 +1,106 @@
import Base: +, -, *, reshape, size
import Base.Broadcast: broadcasted, Broadcasted, BroadcastStyle
"""
Zeros()
Zeros(size...)
Zeros(Type, size...)
Acts as a stand-in for an array of zeros that can be
used during training which is ignored by the optimisers.
Useful to turn bias off for a forward pass of a layer.
## Examples
```julia
julia> Flux.Zeros(3,3)
3×3 Flux.Zeros{Bool,2}:
false false false
false false false
false false false
julia> Flux.Zeros(Float32, 3,3)
3×3 Flux.Zeros{Float32,2}:
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
julia> rand(3,3) .+ Flux.Zeros()
3×3 Array{Float64,2}:
0.198739 0.490459 0.785386
0.779074 0.39986 0.66383
0.854981 0.447292 0.314497
julia> bias_less_conv = Conv((2,2), 1=>3, bias = Flux.Zeros())
Conv((2, 2), 1=>3)
```
"""
struct Zeros{T,N} <: AbstractArray{T,N}
size::Tuple
end
Zeros(::Type{T}, sz...) where T = Zeros{T,length(sz)}(sz)
Zeros(sz::Integer...) = Zeros(Bool, sz...)
Base.size(xs::Zeros) = xs.size
Base.axes(xs::Zeros) = Base.OneTo.(size(xs))
Base.IndexStyle(::Type{<:Zeros}) = IndexLinear()
Base.getindex(xs::Zeros{T,N}, I::Int) where {T,N} = zero(T)
Base.getindex(xs::Zeros{T,N}, inds::Union{Base.OneTo, Base.UnitRange}) where {T,N} =
Zeros(T, length(inds))
Base.collect(xs::Zeros{T,N}) where {T,N} = fill(zero(T), size(xs))
@adjoint reshape(xs::Zeros{T}, dims...) where T =
reshape(xs, dims...), _ -> nothing
# Define basic ops
for f in (:+, :-)
@eval @inline function $f(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros)
@assert size(a) == size(b) throw(DimensionMismatch("dimensions must match"))
a
end
end
+(a::Zeros, b::AbstractArray) = b + a
-(a::Zeros, b::AbstractArray) = -b + a
Base.copy(xs::Zeros{T,N}) where {T,N} = xs
# Define broadcasting behaviour
for op in (:+, :-)
@eval function broadcasted(::typeof($op), a::AbstractArray, b::Zeros)
bs = Broadcast.broadcast_shape(size(a), size(b))
size(a) == bs && return a
sz = similar(a, bs)
sz .= a
end
end
broadcasted(::typeof(+), a::Zeros, b::AbstractArray) = broadcasted(+, b, a)
broadcasted(::typeof(-), a::Zeros, b::AbstractArray) = broadcasted(+, -b, a)
function broadcasted(::typeof(*), a::AbstractArray, b::Zeros)
Zeros(Broadcast.broadcast_shape(size(a), size(b))...)
end
broadcasted(::typeof(*), a::Zeros, b::AbstractArray) = broadcasted(*, b, a)
for op in (:+, :-, :*)
@eval broadcasted(::typeof($op), a::Zeros, b::Zeros) = Zeros(Broadcast.broadcast_shape(size(a), size(b))...)
end
# Some opportunities to avoid scalar indexing, intermediaries
# Since it replicates a little of what we expect Base to do,
# it should be possible to remove in the future, but for now,
# these help with performance.
broadcasted(::typeof(+), a::AbstractArray, b::Zeros{T,0}) where T = a
broadcasted(::typeof(+), a::Zeros{T,0}, b::AbstractArray) where T = b
broadcasted(::typeof(-), a::AbstractArray, b::Zeros{T,0}) where T = a
broadcasted(::typeof(-), a::Zeros{T,0}, b::AbstractArray) where T = -b
broadcasted(::typeof(*), a::AbstractArray, b::Zeros{T,0}) where T = zero(a)
broadcasted(::typeof(*), a::Zeros{T,0}, b::AbstractArray) where T = zero(b)
broadcasted(::typeof(/), a::Zeros{T,0}, b::AbstractArray) where T = zero(b)
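A short sketch of the behaviour these definitions give (values illustrative): adding or subtracting a `Zeros` hands back the other operand unchanged, while multiplying by a `Zeros` collapses to zeros, which is what lets a `Zeros` bias drop out of both the forward pass and the gradients.

```julia
z = Flux.Zeros()                  # 0-dim stand-in, broadcasts against anything
a = rand(Float32, 3, 3)

a .+ z === a                      # true: the array passes through untouched
a .* Flux.Zeros(3, 3)             # a 3×3 Flux.Zeros — the product is all zeros
collect(Flux.Zeros(2, 2))         # materialises as a 2×2 Matrix{Bool} of false
```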


@ -25,6 +25,35 @@ end
Dense(288, 10), softmax)
@test size(m(r)) == (10, 5)
# Test bias switch
bias = Conv(ones(Float32, 2, 2, 1, 3), ones(Float32, 3))
ip = zeros(Float32, 28,28,1,1)
op = bias(ip)
@test sum(op) == prod(size(op))
bias = Conv((2,2), 1=>3, bias = Flux.Zeros())
op = bias(ip)
@test sum(op) === 0.f0
gs = gradient(() -> sum(bias(ip)), Flux.params(bias))
@test gs[bias.bias] == nothing
# Train w/o bias and make sure no convergence happens
# when only bias can be converged
bias = Conv((2, 2), 1=>3, bias = Flux.Zeros());
ip = zeros(Float32, 28,28,1,1)
op = zeros(Float32, 27,27,3,1) .+ 2.f0
opt = Descent()
for _ = 1:10^3
gs = gradient(params(bias)) do
Flux.mse(bias(ip), op)
end
Flux.Optimise.update!(opt, params(bias), gs)
end
@test Flux.mse(bias(ip), op) ≈ 4.f0
end
@testset "asymmetric padding" begin