inferrable

This commit is contained in:
Mike J Innes 2018-06-26 14:05:07 +01:00
parent 3b575930ca
commit 7726a5b605

View File

@ -1,7 +1,9 @@
using NNlib: conv
expand(::Type{Val{N}}, i::Integer) where N = ntuple(_ -> i, Val{N})
expand(::Type{Val{N}}, i::NTuple{N, Integer}) where N = i
@generated sub2(::Type{Val{N}}) where N = :(Val{$(N-2)})
expand(N, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N)
"""
Conv(size, in=>out)
@ -24,9 +26,9 @@ struct Conv{N,F,A,V}
dilation::NTuple{N,Int}
end
Conv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where T =
Conv(σ, w, b, expand.(Val{ndims(w)-2}, (stride, pad, dilation))...)
Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N} =
Conv(σ, w, b, expand.(sub2(Val{N}), (stride, pad, dilation))...)
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
stride = 1, pad = 0, dilation = 1) where N =