Use lower level conv interface

This commit is contained in:
Keno Fischer 2018-10-23 16:15:39 -04:00
parent f98d289579
commit 77bb2a66de

View File

@ -41,7 +41,7 @@ function (c::Conv{<:Any, <:Any, <:Any, stride, pad, dilation})(x) where {stride,
# TODO: breaks gpu broadcast :(
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, stride)..., :, 1)
σ.(conv(x, c.weight, stride = stride, pad = pad, dilation = dilation) .+ b)
σ.(NNlib._conv{pad, stride, dilation}()(x, c.weight) .+ b)
end
function Base.show(io::IO, l::Conv)