diff --git a/src/Flux.jl b/src/Flux.jl
index 5799fe42..90dcb630 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -27,6 +27,7 @@ using CuArrays
 const use_cuda = Ref(false)

 include("utils.jl")
+include("zeros.jl")
 include("onehot.jl")
 include("functor.jl")
diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index f8830fee..1c7282b5 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -30,27 +30,36 @@ function calc_padding(::SamePad, k::NTuple{N,T}, dilation, stride) where {N,T}
 end

 """
-    Conv(size, in => out, σ = identity; init = glorot_uniform,
+    Conv(filter, in => out, σ = identity; init = glorot_uniform,
          stride = 1, pad = 0, dilation = 1)

-Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
+    filter = (2,2)
+    in = 1
+    out = 16
+    Conv((2, 2), 1=>16, relu)
+
+Standard convolutional layer. `filter` should be a tuple like `(2, 2)`.
 `in` and `out` specify the number of input and output channels respectively.

 Data should be stored in WHCN order (width, height, # channels, batch size).
 In other words, a 100×100 RGB image would be a `100×100×3×1` array,
 and a batch of 50 would be a `100×100×3×50` array.

+Accepts keyword arguments `weight` and `bias` to set the corresponding fields.
+Setting `bias` to `Flux.Zeros()` will switch bias off for the layer.
+
+Takes the keyword arguments `pad`, `stride` and `dilation`.
 Use `pad=SamePad()` to apply padding so that outputsize == inputsize / stride.

 # Examples

-Apply a `Conv` layer to a 1-channel input using a 2×2 window size, giving us a
+Apply a `Conv` layer to a 1-channel input using a 2×2 window filter size, giving us a
 16-channel output. Output is activated with ReLU.
 ```julia
-size = (2,2)
+filter = (2,2)
 in = 1
 out = 16
-Conv(size, in => out, relu)
+Conv(filter, in => out, relu)
 ```
 """
 struct Conv{N,M,F,A,V}
@@ -62,7 +71,28 @@ struct Conv{N,M,F,A,V}
   dilation::NTuple{N,Int}
 end

-function Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
+"""
+    Conv(weight::AbstractArray, bias::AbstractArray)
+    Conv(weight::AbstractArray, bias::AbstractArray, activation)
+
+Constructs the convolutional layer with user-defined weight and bias arrays.
+
+Setting `bias` to `Flux.Zeros()` will switch `bias` off for the layer.
+
+Takes the keyword arguments `pad`, `stride` and `dilation`.
+
+There is also a keyword-only constructor available for all convolutional
+layers.
+
+```julia
+weight = rand(Float32, 3, 3, 5)
+bias = zeros(Float32, 5)
+Conv(weight = weight,
+     bias = bias,
+     activation = sigmoid)
+```
+"""
+function Conv(w::AbstractArray{T,N}, b::Union{Zeros, AbstractVector{T}}, σ = identity;
               stride = 1, pad = 0, dilation = 1) where {T,N}
   stride = expand(Val(N-2), stride)
   dilation = expand(Val(N-2), dilation)
@@ -70,10 +100,32 @@ function Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
   return Conv(σ, w, b, stride, pad, dilation)
 end

-Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
-     init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
-  Conv(init(k..., ch...), zeros(ch[2]), σ,
-       stride = stride, pad = pad, dilation = dilation)
+function Conv(;weight::AbstractArray{T,N}, bias::Union{Zeros, AbstractVector{T}},
+              activation = identity, stride = 1, pad = 0, dilation = 1) where {T,N}
+  Conv(weight, bias, activation, stride = stride, pad = pad, dilation = dilation)
+end
+
+"""
+    convfilter(filter::Tuple, in=>out)
+
+Constructs a standard convolutional weight matrix with given `filter` and
+channels from `in` to `out`.
+
+Accepts the keyword `init` (default: `glorot_uniform`) to control the sampling
+distribution.
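+
+For example, a sketch of a 3×3 filter mapping 1 channel to 8 (the sizes are
+purely illustrative):
+
+```julia
+w = convfilter((3, 3), 1 => 8)
+size(w)  # (3, 3, 1, 8)
+```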
+
+See also: [`depthwiseconvfilter`](@ref)
+"""
+convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
+           init = glorot_uniform) where N = init(filter..., ch...)
+
+function Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
+              init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
+              weight = convfilter(k, ch, init = init), bias = zeros(ch[2])) where N
+
+  Conv(weight, bias, σ,
+       stride = stride, pad = pad, dilation = dilation)
+end

 @functor Conv

@@ -114,16 +166,22 @@ outdims(l::Conv, isize) =
   output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))

 """
-    ConvTranspose(size, in => out, σ = identity; init = glorot_uniform,
+    ConvTranspose(filter, in=>out)
+    ConvTranspose(filter, in=>out, activation)
+    ConvTranspose(filter, in => out, σ = identity; init = glorot_uniform,
                   stride = 1, pad = 0, dilation = 1)

-Standard convolutional transpose layer. `size` should be a tuple like `(2, 2)`.
+Standard convolutional transpose layer. `filter` should be a tuple like `(2, 2)`.
 `in` and `out` specify the number of input and output channels respectively.

 Data should be stored in WHCN order (width, height, # channels, batch size).
 In other words, a 100×100 RGB image would be a `100×100×3×1` array,
 and a batch of 50 would be a `100×100×3×50` array.

+Accepts keyword arguments `weight` and `bias` to set the corresponding fields.
+Setting `bias` to `Flux.Zeros()` will switch bias off for the layer.
+
+Takes the keyword arguments `pad`, `stride` and `dilation`.
 Use `pad=SamePad()` to apply padding so that outputsize == stride * inputsize - stride + 1.
 """
 struct ConvTranspose{N,M,F,A,V}
@@ -135,18 +193,39 @@ struct ConvTranspose{N,M,F,A,V}
   dilation::NTuple{N,Int}
 end

-function ConvTranspose(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
-                       stride = 1, pad = 0, dilation = 1) where {T,N}
+"""
+    ConvTranspose(weight::AbstractArray, bias::AbstractArray)
+    ConvTranspose(weight::AbstractArray, bias::AbstractArray, activation)
+
+Constructs the convolutional transpose layer with user-defined weight and bias arrays.
+
+Setting `bias` to `Flux.Zeros()` will switch `bias` off for the layer.
+
+Takes the keyword arguments `pad`, `stride` and `dilation`.
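+
+A rough sketch of direct construction (the weight layout `(filter..., out, in)`
+mirrors the default constructor; the sizes and `stride` below are illustrative
+only):
+
+```julia
+weight = rand(Float32, 3, 3, 4, 2)  # 3×3 filter, out = 4, in = 2
+bias = zeros(Float32, 4)            # one bias value per output channel
+ConvTranspose(weight, bias, relu; stride = 2)
+```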
+
+For keyword-only constructor, see also [`Conv`](@ref)
+"""
+function ConvTranspose(w::AbstractArray{T,N}, b::Union{Zeros, AbstractVector{T}}, σ = identity;
+                       stride = 1, pad = 0, dilation = 1) where {T,N}
   stride = expand(Val(N-2), stride)
   dilation = expand(Val(N-2), dilation)
   pad = calc_padding(pad, size(w)[1:N-2], dilation, stride)
   return ConvTranspose(σ, w, b, stride, pad, dilation)
 end

-ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
-              init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
-ConvTranspose(init(k..., reverse(ch)...), zeros(ch[2]), σ,
+function ConvTranspose(;weight::AbstractArray{T,N}, bias::Union{Zeros, AbstractVector{T}},
+                       activation = identity, stride = 1, pad = 0, dilation = 1) where {T,N}
+  ConvTranspose(weight, bias, activation, stride = stride, pad = pad, dilation = dilation)
+end
+
+function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
+                       init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
+                       weight = convfilter(k, reverse(ch), init = init), bias = zeros(ch[2])) where N
+
+  ConvTranspose(weight, bias, σ,
               stride = stride, pad = pad, dilation = dilation)
+end

 @functor ConvTranspose

@@ -158,9 +237,9 @@ function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
   batch_size = size(x)[end]
   # Create DenseConvDims() that looks like the corresponding conv()
   return DenseConvDims((I..., C_in, batch_size), size(c.weight);
-                      stride=c.stride,
-                      padding=c.pad,
-                      dilation=c.dilation,
+                       stride=c.stride,
+                       padding=c.pad,
+                       dilation=c.dilation,
   )
 end

@@ -171,7 +250,7 @@ function (c::ConvTranspose)(x::AbstractArray)
   # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
   σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
   cdims = conv_transpose_dims(c, x)
-  return σ.(∇conv_data(x, c.weight, cdims) .+ b)
+  σ.(∇conv_data(x, c.weight, cdims) .+ b)
 end

 function Base.show(io::IO, l::ConvTranspose)
@@ -190,10 +269,12 @@ end
 outdims(l::ConvTranspose{N}, isize) where N = _convtransoutdims(isize[1:2], size(l.weight)[1:N], l.stride, l.dilation, l.pad)

 """
-    DepthwiseConv(size, in => out, σ = identity; init = glorot_uniform,
+    DepthwiseConv(filter::Tuple, in=>out)
+    DepthwiseConv(filter::Tuple, in=>out, activation)
+    DepthwiseConv(filter, in => out, σ = identity; init = glorot_uniform,
                   stride = 1, pad = 0, dilation = 1)

-Depthwise convolutional layer. `size` should be a tuple like `(2, 2)`.
+Depthwise convolutional layer. `filter` should be a tuple like `(2, 2)`.
 `in` and `out` specify the number of input and output channels respectively.
 Note that `out` must be an integer multiple of `in`.

@@ -201,6 +282,10 @@ Data should be stored in WHCN order (width, height, # channels, batch size).
 In other words, a 100×100 RGB image would be a `100×100×3×1` array,
 and a batch of 50 would be a `100×100×3×50` array.

+Accepts keyword arguments `weight` and `bias` to set the corresponding fields.
+Setting `bias` to `Flux.Zeros()` will switch bias off for the layer.
+
+Takes the keyword arguments `pad`, `stride` and `dilation`.
 Use `pad=SamePad()` to apply padding so that outputsize == inputsize / stride.
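+
+# Examples
+
+A sketch of typical usage (the channel counts are illustrative; `out` must be
+an integer multiple of `in`):
+
+```julia
+# one 3×3 filter per input channel, expanding 3 channels to 6
+DepthwiseConv((3, 3), 3 => 6, relu)
+```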
""" struct DepthwiseConv{N,M,F,A,V} @@ -212,20 +297,54 @@ struct DepthwiseConv{N,M,F,A,V} dilation::NTuple{N,Int} end -function DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity; - stride = 1, pad = 0, dilation = 1) where {T,N} +""" + DepthwiseConv(weight::AbstractArray, bias::AbstractArray) + DepthwiseConv(weight::AbstractArray, bias::AbstractArray, activation) + +Constructs the `DepthwiseConv` layer with user defined weight and bias arrays. +forward pass. + +Setting `bias` to `Flux.Zeros()` would switch `bias` off for the layer. + +Takes the keyword arguments `pad`, `stride` and `dilation`. + +For keyword-only constuctor, see also [`Conv`](@ref) +""" +function DepthwiseConv(w::AbstractArray{T,N}, b::Union{Zeros, AbstractVector{T}}, σ = identity; + stride = 1, pad = 0, dilation = 1) where {T,N} stride = expand(Val(N-2), stride) dilation = expand(Val(N-2), dilation) pad = calc_padding(pad, size(w)[1:N-2], dilation, stride) return DepthwiseConv(σ, w, b, stride, pad, dilation) end +function DepthwiseConv(;weight::AbstractArray{T,N}, bias::Union{Zeros, AbstractVector{T}}, + activation = identity, stride = 1, pad = 0, dilation = 1) where {T,N} + DepthwiseConv(weight, bias, activation, stride = stride, pad = pad, dilation = dilation) +end + +""" + depthwiseconvfilter(filter::Tuple, in=>out) + +Constructs a depthwise convolutional weight array defined by `filter` and channels +from `in` to `out`. + +Accepts the keyword `init` (default: `glorot_uniform`) to control the sampling +distribution. + +See also: [`convfilter`](@ref) +""" +depthwiseconvfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}; + init = glorot_uniform) where N = init(filter..., div(ch[2], ch[1]), ch[1]) + function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; - init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N + init = glorot_uniform, stride = 1, pad = 0, dilation = 1, + weight = depthwiseconvfilter(k, ch, init = init), bias = zeros(ch[2])) where N @assert ch[2] % ch[1] == 0 "Output channels must be integer multiple of input channels" + return DepthwiseConv( - init(k..., div(ch[2], ch[1]), ch[1]), - zeros(ch[2]), + weight, + bias, σ; stride = stride, pad = pad, @@ -258,24 +377,30 @@ outdims(l::DepthwiseConv, isize) = output_size(DepthwiseConvDims(_paddims(isize, (1, 1, size(l.weight)[end], 1)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation)) """ - CrossCor(size, in => out, σ = identity; init = glorot_uniform, + CrossCor(filter, in=>out) + CrossCor(filter, in=>out, activation) + CrossCor(filter, in => out, σ = identity; init = glorot_uniform, stride = 1, pad = 0, dilation = 1) -Standard cross convolutional layer. `size` should be a tuple like `(2, 2)`. +Standard cross convolutional layer. `filter` should be a tuple like `(2, 2)`. `in` and `out` specify the number of input and output channels respectively. Data should be stored in WHCN order (width, height, # channels, batch size). In other words, a 100×100 RGB image would be a `100×100×3×1` array, and a batch of 50 would be a `100×100×3×50` array. +Accepts keyword arguments `weight` and `bias` to set the corresponding fields. +Setting `bias` to `Flux.Zeros()` will switch bias off for the layer. + +Takes the keyword arguments `pad`, `stride` and `dilation`. Use `pad=SamePad()` to apply padding so that outputsize == inputsize / stride. 

 # Examples

-Apply a `CrossCor` layer to a 1-channel input using a 2×2 window size, giving us a
+Apply a `CrossCor` layer to a 1-channel input using a 2×2 window filter size, giving us a
 16-channel output. Output is activated with ReLU.
 ```julia
-size = (2,2)
+filter = (2,2)
 in = 1
 out = 16
 CrossCor((2, 2), 1=>16, relu)
@@ -290,18 +415,39 @@ struct CrossCor{N,M,F,A,V}
   dilation::NTuple{N,Int}
 end

-function CrossCor(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
-                  stride = 1, pad = 0, dilation = 1) where {T,N}
+"""
+    CrossCor(weight::AbstractArray, bias::AbstractArray)
+    CrossCor(weight::AbstractArray, bias::AbstractArray, activation)
+
+Constructs the standard cross convolutional layer with user-defined weight and bias
+arrays.
+
+Setting `bias` to `Flux.Zeros()` will switch `bias` off for the layer.
+
+Takes the keyword arguments `pad`, `stride` and `dilation`.
+
+For keyword-only constructor, see also [`Conv`](@ref)
+"""
+function CrossCor(w::AbstractArray{T,N}, b::Union{Zeros, AbstractVector{T}}, σ = identity;
+                  stride = 1, pad = 0, dilation = 1) where {T,N}
   stride = expand(Val(N-2), stride)
   dilation = expand(Val(N-2), dilation)
   pad = calc_padding(pad, size(w)[1:N-2], dilation, stride)
   return CrossCor(σ, w, b, stride, pad, dilation)
 end

-CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
-         init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
-  CrossCor(init(k..., ch...), zeros(ch[2]), σ,
+function CrossCor(;weight::AbstractArray{T,N}, bias::Union{Zeros, AbstractVector{T}},
+                  activation = identity, stride = 1, pad = 0, dilation = 1) where {T,N}
+  CrossCor(weight, bias, activation, stride = stride, pad = pad, dilation = dilation)
+end
+
+function CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
+                  init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
+                  weight = convfilter(k, ch, init = init), bias = zeros(ch[2])) where N
+
+  CrossCor(weight, bias, σ,
         stride = stride, pad = pad, dilation = dilation)
+end

 @functor CrossCor

diff --git a/src/zeros.jl b/src/zeros.jl
new file mode 100644
index 00000000..1aec7b02
--- /dev/null
+++ b/src/zeros.jl
@@ -0,0 +1,106 @@
+import Base: +, -, *, reshape, size
+import Base.Broadcast: broadcasted, Broadcasted, BroadcastStyle
+
+"""
+    Zeros()
+    Zeros(size...)
+    Zeros(Type, size...)
+
+Acts as a stand-in for an array of zeros that can be
+used during training, and is ignored by the optimisers.
+
+Useful to turn bias off for a layer's forward pass.
+
+## Examples
+
+```julia
+julia> Flux.Zeros(3,3)
+3×3 Flux.Zeros{Bool,2}:
+ false  false  false
+ false  false  false
+ false  false  false
+
+julia> Flux.Zeros(Float32, 3,3)
+3×3 Flux.Zeros{Float32,2}:
+ 0.0  0.0  0.0
+ 0.0  0.0  0.0
+ 0.0  0.0  0.0
+
+julia> rand(3,3) .+ Flux.Zeros()
+3×3 Array{Float64,2}:
+ 0.198739  0.490459  0.785386
+ 0.779074  0.39986   0.66383
+ 0.854981  0.447292  0.314497
+
+julia> bias_less_conv = Conv((2,2), 1=>3, bias = Flux.Zeros())
+Conv((2, 2), 1=>3)
+```
+"""
+struct Zeros{T,N} <: AbstractArray{T,N}
+  size::Tuple
+end
+
+Zeros(::Type{T}, sz...) where T = Zeros{T,length(sz)}(sz)
+Zeros(sz::Integer...) = Zeros(Bool, sz...)
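+
+# With no element type given, `Bool` is used as the element type, keeping the
+# placeholder as cheap as possible. The methods below give `Zeros` just enough
+# of the `AbstractArray` interface to take part in indexing and broadcasting.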
+
+Base.size(xs::Zeros) = xs.size
+Base.axes(xs::Zeros) = Base.OneTo.(size(xs))
+
+Base.IndexStyle(::Type{<:Zeros}) = IndexLinear()
+
+Base.getindex(xs::Zeros{T,N}, I::Int) where {T,N} = zero(T)
+Base.getindex(xs::Zeros{T,N}, inds::Union{Base.OneTo, Base.UnitRange}) where {T,N} =
+  Zeros(T, length(inds))
+
+Base.collect(xs::Zeros{T,N}) where {T,N} = fill(zero(T), size(xs))
+
+@adjoint reshape(xs::Zeros{T}, dims...) where T =
+  reshape(xs, dims...), _ -> nothing
+
+# Define basic ops
+for f in (:+, :-)
+  @eval @inline function $f(a::Union{AbstractArray{<:Number}, Zeros}, b::Zeros)
+    @assert size(a) == size(b) throw(DimensionMismatch("dimensions must match"))
+    a
+  end
+end
+
++(a::Zeros, b::AbstractArray) = b + a
+-(a::Zeros, b::AbstractArray) = -b + a
+
+Base.copy(xs::Zeros{T,N}) where {T,N} = xs
+
+# Define broadcasting behaviour
+for op in (:+, :-)
+  @eval function broadcasted(::typeof($op), a::AbstractArray, b::Zeros)
+    bs = Broadcast.broadcast_shape(size(a), size(b))
+    size(a) == bs && return a
+    sz = similar(a, bs)
+    sz .= a
+  end
+end
+
+broadcasted(::typeof(+), a::Zeros, b::AbstractArray) = broadcasted(+, b, a)
+broadcasted(::typeof(-), a::Zeros, b::AbstractArray) = broadcasted(+, -b, a)
+
+function broadcasted(::typeof(*), a::AbstractArray, b::Zeros)
+  Zeros(Broadcast.broadcast_shape(size(a), size(b))...)
+end
+
+broadcasted(::typeof(*), a::Zeros, b::AbstractArray) = broadcasted(*, b, a)
+
+for op in (:+, :-, :*)
+  @eval broadcasted(::typeof($op), a::Zeros, b::Zeros) = Zeros(Broadcast.broadcast_shape(size(a), size(b))...)
+end
+
+# Some opportunities to avoid scalar indexing and intermediate allocations.
+# Since it replicates a little of what we expect Base to do,
+# it should be possible to remove in the future, but for now,
+# these help with performance.
+broadcasted(::typeof(+), a::AbstractArray, b::Zeros{T,0}) where T = a
+broadcasted(::typeof(+), a::Zeros{T,0}, b::AbstractArray) where T = b
+broadcasted(::typeof(-), a::AbstractArray, b::Zeros{T,0}) where T = a
+broadcasted(::typeof(-), a::Zeros{T,0}, b::AbstractArray) where T = -b
+broadcasted(::typeof(*), a::AbstractArray, b::Zeros{T,0}) where T = zero(a)
+broadcasted(::typeof(*), a::Zeros{T,0}, b::AbstractArray) where T = zero(b)
+broadcasted(::typeof(/), a::Zeros{T,0}, b::AbstractArray) where T = zero(b)
diff --git a/test/layers/conv.jl b/test/layers/conv.jl
index 97355b18..8c825bfd 100644
--- a/test/layers/conv.jl
+++ b/test/layers/conv.jl
@@ -25,6 +25,35 @@ end
     Dense(288, 10), softmax)

   @test size(m(r)) == (10, 5)
+
+  # Test bias switch
+  bias = Conv(ones(Float32, 2, 2, 1, 3), ones(Float32, 3))
+  ip = zeros(Float32, 28,28,1,1)
+
+  op = bias(ip)
+  @test sum(op) == prod(size(op))
+
+  bias = Conv((2,2), 1=>3, bias = Flux.Zeros())
+  op = bias(ip)
+  @test sum(op) === 0.f0
+  gs = gradient(() -> sum(bias(ip)), Flux.params(bias))
+  @test gs[bias.bias] == nothing
+
+  # Train w/o bias and make sure no convergence happens, since with an all-zero
+  # input only the bias could have fit the target, and it is switched off.
+  bias = Conv((2, 2), 1=>3, bias = Flux.Zeros());
+  ip = zeros(Float32, 28,28,1,1)
+  op = zeros(Float32, 27,27,3,1) .+ 2.f0
+  opt = Descent()
+
+  for _ = 1:10^3
+    gs = gradient(params(bias)) do
+      Flux.mse(bias(ip), op)
+    end
+    Flux.Optimise.update!(opt, params(bias), gs)
+  end
+
+  @test Flux.mse(bias(ip), op) ≈ 4.f0
 end

 @testset "asymmetric padding" begin