add bias and weight kwarg

This commit is contained in:
Dhairya Gandhi 2019-10-08 17:18:19 +05:30
parent f3904b4e04
commit 040697fb2b

View File

@ -3,16 +3,16 @@ using NNlib: conv, ∇conv_data, depthwiseconv
expand(N, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N)
"""
Conv(size, in=>out)
Conv(size, in=>out, relu)
Conv(filter::Tuple, in=>out)
Conv(filter::Tuple, in=>out, activation)
Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
Standard convolutional layer. `filter` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.
Example: Applying Conv layer to a 1-channel input using a 2x2 window size,
giving us a 16-channel output. Output is activated with ReLU.
size = (2,2)
filter = (2,2)
in = 1
out = 16
Conv((2, 2), 1=>16, relu)
@ -34,7 +34,7 @@ end
"""
Conv(weight::AbstractArray, bias::AbstractArray)
Conv(weight::AbstractArray, bias::AbstractArray, relu)
Conv(weight::AbstractArray, bias::AbstractArray, activation)
Constructs the convolutional layer with user defined weight and bias arrays.
All other behaviours of the Conv layer apply with regard to data order and
@ -42,21 +42,32 @@ forward pass.
Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
function Conv(w::AbstractArray{T,N}, b::Union{Number, AbstractVector{T}}, σ = identity;
function Conv(w::AbstractArray{T,N}, b::Union{Nothing, ZeroType, AbstractVector{T}}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
stride = expand(Val(N-2), stride)
pad = expand(Val(2*(N-2)), pad)
dilation = expand(Val(N-2), dilation)
b = b isa Nothing ? ZeroType((size(w, ndims(w)), )) : b
return Conv(σ, w, b, stride, pad, dilation)
end
convweight(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
init = glorot_uniform) where N = init(k..., ch...)
const convbias = zeros
"""
convweight(filter::Tuple, in=>out)
Constructs a standard convolutional weight matrix with given `filter` and
channels from `in` to `out`.
Accepts the keyword `init` (default: `glorot_uniform`) to control the sampling
distribution.
See also: [`depthwiseconvweight`](@ref)
"""
convweight(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
init = glorot_uniform) where N = init(filter..., ch...)
function Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
weight = convweight(k, ch, init = init), bias = convbias(ch[2])) where N
weight = convweight(k, ch, init = init), bias = zeros(ch[2])) where N
Conv(weight, bias, σ,
stride = stride, pad = pad, dilation = dilation)
@ -86,10 +97,10 @@ end
a(T.(x))
"""
ConvTranspose(size, in=>out)
ConvTranspose(size, in=>out, relu)
ConvTranspose(filter::Tuple, in=>out)
ConvTranspose(filter::Tuple, in=>out, relu)
Standard convolutional transpose layer. `size` should be a tuple like `(2, 2)`.
Standard convolutional transpose layer. `filter` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.
Data should be stored in WHCN order. In other words, a 100×100 RGB image would
@ -106,17 +117,28 @@ struct ConvTranspose{N,M,F,A,V}
dilation::NTuple{N,Int}
end
function ConvTranspose(w::AbstractArray{T,N}, b::Union{Number, AbstractVector{T}}, σ = identity;
"""
ConvTranspose(weight::AbstractArray, bias::AbstractArray)
ConvTranspose(weight::AbstractArray, bias::AbstractArray, activation)
Constructs the convolutional transpose layer with user defined weight and bias arrays.
All other behaviours of the ConvTranspose layer apply with regard to data order and
forward pass.
Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
function ConvTranspose(w::AbstractArray{T,N}, b::Union{Nothing, ZeroType, AbstractVector{T}}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
stride = expand(Val(N-2), stride)
pad = expand(Val(2*(N-2)), pad)
dilation = expand(Val(N-2), dilation)
b = b isa Nothing ? ZeroType((size(w, ndims(w)), )) : b
return ConvTranspose(σ, w, b, stride, pad, dilation)
end
function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
weight = convweight(k, reverse(ch), init = init), bias = convbias(ch[2])) where N
weight = convweight(k, reverse(ch), init = init), bias = zeros(ch[2])) where N
ConvTranspose(weight, bias, σ,
stride = stride, pad = pad, dilation = dilation)
@ -157,11 +179,12 @@ end
(a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))
"""
DepthwiseConv(size, in=>out)
DepthwiseConv(size, in=>out, relu)
Depthwise convolutional layer. `size` should be a tuple like `(2, 2)`.
"""
DepthwiseConv(filter::Tuple, in=>out)
DepthwiseConv(filter::Tuple, in=>out, relu)
Depthwise convolutional layer. `filter` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.
Note that `out` must be an integer multiple of `in`.
@ -179,21 +202,44 @@ struct DepthwiseConv{N,M,F,A,V}
dilation::NTuple{N,Int}
end
function DepthwiseConv(w::AbstractArray{T,N}, b::Union{Number, AbstractVector{T}}, σ = identity;
"""
DepthwiseConv(weight::AbstractArray, bias::AbstractArray)
DepthwiseConv(weight::AbstractArray, bias::AbstractArray, activation)
Constructs the `DepthwiseConv` layer with user defined weight and bias arrays.
All other behaviours of the `DepthwiseConv` layer apply with regard to data order and
forward pass.
Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
function DepthwiseConv(w::AbstractArray{T,N}, b::Union{Nothing, ZeroType, AbstractVector{T}}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
stride = expand(Val(N-2), stride)
pad = expand(Val(2*(N-2)), pad)
dilation = expand(Val(N-2), dilation)
b = b isa Nothing ? ZeroType((size(w, ndims(w)), )) : b
return DepthwiseConv(σ, w, b, stride, pad, dilation)
end
depthwiseconvweight(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
init = glorot_uniform) where N = init(k..., div(ch[2], ch[1]), ch[1])
"""
depthwiseconvweight(filter::Tuple, in=>out)
Constructs a depthwise convolutional weight array defined by `filter` and channels
from `in` to `out`.
Accepts the keyword `init` (default: `glorot_uniform`) to control the sampling
distribution.
See also: [`convweight`](@ref)
"""
depthwiseconvweight(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
init = glorot_uniform) where N = init(filter..., div(ch[2], ch[1]), ch[1])
function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
weight = depthwiseconvweight(k, ch, init = init), bias = convbias(ch[2])) where N
weight = depthwiseconvweight(k, ch, init = init), bias = zeros(ch[2])) where N
@assert ch[2] % ch[1] == 0 "Output channels must be integer multiple of input channels"
return DepthwiseConv(
weight,
bias,
@ -255,17 +301,29 @@ struct CrossCor{N,M,F,A,V}
dilation::NTuple{N,Int}
end
function CrossCor(w::AbstractArray{T,N}, b::Union{Number, AbstractVector{T}}, σ = identity;
"""
CrossCor(weight::AbstractArray, bias::AbstractArray)
CrossCor(weight::AbstractArray, bias::AbstractArray, activation)
Constructs the standard cross convolutional layer with user defined weight and bias
arrays. All other behaviours of the CrossCor layer apply with regard to data order and
forward pass.
Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
function CrossCor(w::AbstractArray{T,N}, b::Union{Nothing, ZeroType, AbstractVector{T}}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
stride = expand(Val(N-2), stride)
pad = expand(Val(2*(N-2)), pad)
dilation = expand(Val(N-2), dilation)
b = b isa Nothing ? ZeroType((size(w, ndims(w)), )) : b
return CrossCor(σ, w, b, stride, pad, dilation)
end
function CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
weight = convweight(k, ch, init = init), bias = convbias(ch[2])) where N
weight = convweight(k, ch, init = init), bias = zeros(ch[2])) where N
CrossCor(weight, bias, σ,
stride = stride, pad = pad, dilation = dilation)
end