replace weight with filter

This commit is contained in:
Dhairya Gandhi 2019-10-08 20:26:09 +05:30
parent 49ea43e711
commit c85bad4427

View File

@ -59,7 +59,7 @@ function Conv(w::AbstractArray{T,N}, b::Union{Nothing, Zeros, AbstractVector{T}}
end end
""" """
convweight(filter::Tuple, in=>out) convfilter(filter::Tuple, in=>out)
Constructs a standard convolutional weight matrix with given `filter` and Constructs a standard convolutional weight matrix with given `filter` and
channels from `in` to `out`. channels from `in` to `out`.
@ -67,14 +67,14 @@ channels from `in` to `out`.
Accepts the keyword `init` (default: `glorot_uniform`) to control the sampling Accepts the keyword `init` (default: `glorot_uniform`) to control the sampling
distribution. distribution.
See also: [`depthwiseconvweight`](@ref) See also: [`depthwiseconvfilter`](@ref)
""" """
convweight(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}; convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
init = glorot_uniform) where N = init(filter..., ch...) init = glorot_uniform) where N = init(filter..., ch...)
function Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; function Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1, init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
weight = convweight(k, ch, init = init), bias = zeros(ch[2])) where N weight = convfilter(k, ch, init = init), bias = zeros(ch[2])) where N
Conv(weight, bias, σ, Conv(weight, bias, σ,
stride = stride, pad = pad, dilation = dilation) stride = stride, pad = pad, dilation = dilation)
@ -152,7 +152,7 @@ end
function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1, init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
weight = convweight(k, reverse(ch), init = init), bias = zeros(ch[2])) where N weight = convfilter(k, reverse(ch), init = init), bias = zeros(ch[2])) where N
ConvTranspose(weight, bias, σ, ConvTranspose(weight, bias, σ,
stride = stride, pad = pad, dilation = dilation) stride = stride, pad = pad, dilation = dilation)
@ -243,7 +243,7 @@ function DepthwiseConv(w::AbstractArray{T,N}, b::Union{Nothing, Zeros, AbstractV
end end
""" """
depthwiseconvweight(filter::Tuple, in=>out) depthwiseconvfilter(filter::Tuple, in=>out)
Constructs a depthwise convolutional weight array defined by `filter` and channels Constructs a depthwise convolutional weight array defined by `filter` and channels
from `in` to `out`. from `in` to `out`.
@ -251,14 +251,14 @@ from `in` to `out`.
Accepts the keyword `init` (default: `glorot_uniform`) to control the sampling Accepts the keyword `init` (default: `glorot_uniform`) to control the sampling
distribution. distribution.
See also: [`convweight`](@ref) See also: [`convfilter`](@ref)
""" """
depthwiseconvweight(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}; depthwiseconvfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
init = glorot_uniform) where N = init(filter..., div(ch[2], ch[1]), ch[1]) init = glorot_uniform) where N = init(filter..., div(ch[2], ch[1]), ch[1])
function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1, init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
weight = depthwiseconvweight(k, ch, init = init), bias = zeros(ch[2])) where N weight = depthwiseconvfilter(k, ch, init = init), bias = zeros(ch[2])) where N
@assert ch[2] % ch[1] == 0 "Output channels must be integer multiple of input channels" @assert ch[2] % ch[1] == 0 "Output channels must be integer multiple of input channels"
return DepthwiseConv( return DepthwiseConv(
@ -350,7 +350,7 @@ end
function CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; function CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1, init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
weight = convweight(k, ch, init = init), bias = zeros(ch[2])) where N weight = convfilter(k, ch, init = init), bias = zeros(ch[2])) where N
CrossCor(weight, bias, σ, CrossCor(weight, bias, σ,
stride = stride, pad = pad, dilation = dilation) stride = stride, pad = pad, dilation = dilation)