replace weight with filter
This commit is contained in:
parent
49ea43e711
commit
c85bad4427
@ -59,7 +59,7 @@ function Conv(w::AbstractArray{T,N}, b::Union{Nothing, Zeros, AbstractVector{T}}
|
|||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
convweight(filter::Tuple, in=>out)
|
convfilter(filter::Tuple, in=>out)
|
||||||
|
|
||||||
Constructs a standard convolutional weight matrix with given `filter` and
|
Constructs a standard convolutional weight matrix with given `filter` and
|
||||||
channels from `in` to `out`.
|
channels from `in` to `out`.
|
||||||
@ -67,14 +67,14 @@ channels from `in` to `out`.
|
|||||||
Accepts the keyword `init` (default: `glorot_uniform`) to control the sampling
|
Accepts the keyword `init` (default: `glorot_uniform`) to control the sampling
|
||||||
distribution.
|
distribution.
|
||||||
|
|
||||||
See also: [`depthwiseconvweight`](@ref)
|
See also: [`depthwiseconvfilter`](@ref)
|
||||||
"""
|
"""
|
||||||
convweight(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
|
convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
|
||||||
init = glorot_uniform) where N = init(filter..., ch...)
|
init = glorot_uniform) where N = init(filter..., ch...)
|
||||||
|
|
||||||
function Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
function Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
||||||
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
|
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
|
||||||
weight = convweight(k, ch, init = init), bias = zeros(ch[2])) where N
|
weight = convfilter(k, ch, init = init), bias = zeros(ch[2])) where N
|
||||||
|
|
||||||
Conv(weight, bias, σ,
|
Conv(weight, bias, σ,
|
||||||
stride = stride, pad = pad, dilation = dilation)
|
stride = stride, pad = pad, dilation = dilation)
|
||||||
@ -152,7 +152,7 @@ end
|
|||||||
|
|
||||||
function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
||||||
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
|
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
|
||||||
weight = convweight(k, reverse(ch), init = init), bias = zeros(ch[2])) where N
|
weight = convfilter(k, reverse(ch), init = init), bias = zeros(ch[2])) where N
|
||||||
|
|
||||||
ConvTranspose(weight, bias, σ,
|
ConvTranspose(weight, bias, σ,
|
||||||
stride = stride, pad = pad, dilation = dilation)
|
stride = stride, pad = pad, dilation = dilation)
|
||||||
@ -243,7 +243,7 @@ function DepthwiseConv(w::AbstractArray{T,N}, b::Union{Nothing, Zeros, AbstractV
|
|||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
depthwiseconvweight(filter::Tuple, in=>out)
|
depthwiseconvfilter(filter::Tuple, in=>out)
|
||||||
|
|
||||||
Constructs a depthwise convolutional weight array defined by `filter` and channels
|
Constructs a depthwise convolutional weight array defined by `filter` and channels
|
||||||
from `in` to `out`.
|
from `in` to `out`.
|
||||||
@ -251,14 +251,14 @@ from `in` to `out`.
|
|||||||
Accepts the keyword `init` (default: `glorot_uniform`) to control the sampling
|
Accepts the keyword `init` (default: `glorot_uniform`) to control the sampling
|
||||||
distribution.
|
distribution.
|
||||||
|
|
||||||
See also: [`convweight`](@ref)
|
See also: [`convfilter`](@ref)
|
||||||
"""
|
"""
|
||||||
depthwiseconvweight(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
|
depthwiseconvfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
|
||||||
init = glorot_uniform) where N = init(filter..., div(ch[2], ch[1]), ch[1])
|
init = glorot_uniform) where N = init(filter..., div(ch[2], ch[1]), ch[1])
|
||||||
|
|
||||||
function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
||||||
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
|
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
|
||||||
weight = depthwiseconvweight(k, ch, init = init), bias = zeros(ch[2])) where N
|
weight = depthwiseconvfilter(k, ch, init = init), bias = zeros(ch[2])) where N
|
||||||
@assert ch[2] % ch[1] == 0 "Output channels must be integer multiple of input channels"
|
@assert ch[2] % ch[1] == 0 "Output channels must be integer multiple of input channels"
|
||||||
|
|
||||||
return DepthwiseConv(
|
return DepthwiseConv(
|
||||||
@ -350,7 +350,7 @@ end
|
|||||||
|
|
||||||
function CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
function CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
||||||
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
|
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
|
||||||
weight = convweight(k, ch, init = init), bias = zeros(ch[2])) where N
|
weight = convfilter(k, ch, init = init), bias = zeros(ch[2])) where N
|
||||||
|
|
||||||
CrossCor(weight, bias, σ,
|
CrossCor(weight, bias, σ,
|
||||||
stride = stride, pad = pad, dilation = dilation)
|
stride = stride, pad = pad, dilation = dilation)
|
||||||
|
Loading…
Reference in New Issue
Block a user