inferrable
This commit is contained in:
parent
3b575930ca
commit
7726a5b605
@ -1,7 +1,9 @@
|
||||
using NNlib: conv
|
||||
|
||||
expand(::Type{Val{N}}, i::Integer) where N = ntuple(_ -> i, Val{N})
|
||||
expand(::Type{Val{N}}, i::NTuple{N, Integer}) where N = i
|
||||
@generated sub2(::Type{Val{N}}) where N = :(Val{$(N-2)})
|
||||
|
||||
expand(N, i::Tuple) = i
|
||||
expand(N, i::Integer) = ntuple(_ -> i, N)
|
||||
|
||||
"""
|
||||
Conv(size, in=>out)
|
||||
@ -24,9 +26,9 @@ struct Conv{N,F,A,V}
|
||||
dilation::NTuple{N,Int}
|
||||
end
|
||||
|
||||
Conv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity;
|
||||
stride = 1, pad = 0, dilation = 1) where T =
|
||||
Conv(σ, w, b, expand.(Val{ndims(w)-2}, (stride, pad, dilation))...)
|
||||
Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
||||
stride = 1, pad = 0, dilation = 1) where {T,N} =
|
||||
Conv(σ, w, b, expand.(sub2(Val{N}), (stride, pad, dilation))...)
|
||||
|
||||
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
|
||||
stride = 1, pad = 0, dilation = 1) where N =
|
||||
|
Loading…
Reference in New Issue
Block a user