replace weight with filter

This commit is contained in:
Dhairya Gandhi 2019-10-08 20:26:09 +05:30
parent 49ea43e711
commit c85bad4427

View File

@ -59,7 +59,7 @@ function Conv(w::AbstractArray{T,N}, b::Union{Nothing, Zeros, AbstractVector{T}}
end
"""
convweight(filter::Tuple, in=>out)
convfilter(filter::Tuple, in=>out)
Constructs a standard convolutional weight matrix with given `filter` and
channels from `in` to `out`.
@ -67,14 +67,14 @@ channels from `in` to `out`.
Accepts the keyword `init` (default: `glorot_uniform`) to control the sampling
distribution.
See also: [`depthwiseconvweight`](@ref)
See also: [`depthwiseconvfilter`](@ref)
"""
convweight(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
init = glorot_uniform) where N = init(filter..., ch...)
function Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
weight = convweight(k, ch, init = init), bias = zeros(ch[2])) where N
weight = convfilter(k, ch, init = init), bias = zeros(ch[2])) where N
Conv(weight, bias, σ,
stride = stride, pad = pad, dilation = dilation)
@ -152,7 +152,7 @@ end
function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
weight = convweight(k, reverse(ch), init = init), bias = zeros(ch[2])) where N
weight = convfilter(k, reverse(ch), init = init), bias = zeros(ch[2])) where N
ConvTranspose(weight, bias, σ,
stride = stride, pad = pad, dilation = dilation)
@ -243,7 +243,7 @@ function DepthwiseConv(w::AbstractArray{T,N}, b::Union{Nothing, Zeros, AbstractV
end
"""
depthwiseconvweight(filter::Tuple, in=>out)
depthwiseconvfilter(filter::Tuple, in=>out)
Constructs a depthwise convolutional weight array defined by `filter` and channels
from `in` to `out`.
@ -251,14 +251,14 @@ from `in` to `out`.
Accepts the keyword `init` (default: `glorot_uniform`) to control the sampling
distribution.
See also: [`convweight`](@ref)
See also: [`convfilter`](@ref)
"""
depthwiseconvweight(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
depthwiseconvfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
init = glorot_uniform) where N = init(filter..., div(ch[2], ch[1]), ch[1])
function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
weight = depthwiseconvweight(k, ch, init = init), bias = zeros(ch[2])) where N
weight = depthwiseconvfilter(k, ch, init = init), bias = zeros(ch[2])) where N
@assert ch[2] % ch[1] == 0 "Output channels must be integer multiple of input channels"
return DepthwiseConv(
@ -350,7 +350,7 @@ end
function CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
weight = convweight(k, ch, init = init), bias = zeros(ch[2])) where N
weight = convfilter(k, ch, init = init), bias = zeros(ch[2])) where N
CrossCor(weight, bias, σ,
stride = stride, pad = pad, dilation = dilation)