From 7726a5b605832b0fc35e26332a0f0f83a5d5f210 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 26 Jun 2018 14:05:07 +0100 Subject: [PATCH] inferrable --- src/layers/conv.jl | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 7548fc96..38310aad 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1,7 +1,9 @@ using NNlib: conv -expand(::Type{Val{N}}, i::Integer) where N = ntuple(_ -> i, Val{N}) -expand(::Type{Val{N}}, i::NTuple{N, Integer}) where N = i +@generated sub2(::Type{Val{N}}) where N = :(Val{$(N-2)}) + +expand(N, i::Tuple) = i +expand(N, i::Integer) = ntuple(_ -> i, N) """ Conv(size, in=>out) @@ -24,9 +26,9 @@ struct Conv{N,F,A,V} dilation::NTuple{N,Int} end -Conv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity; - stride = 1, pad = 0, dilation = 1) where T = - Conv(σ, w, b, expand.(Val{ndims(w)-2}, (stride, pad, dilation))...) +Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity; + stride = 1, pad = 0, dilation = 1) where {T,N} = + Conv(σ, w, b, expand.(sub2(Val{N}), (stride, pad, dilation))...) Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn, stride = 1, pad = 0, dilation = 1) where N =