Use lower level conv interface
This commit is contained in:
parent
f98d289579
commit
77bb2a66de
@ -41,7 +41,7 @@ function (c::Conv{<:Any, <:Any, <:Any, stride, pad, dilation})(x) where {stride,
|
||||
# TODO: breaks gpu broadcast :(
|
||||
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
|
||||
σ, b = c.σ, reshape(c.bias, map(_->1, stride)..., :, 1)
|
||||
σ.(conv(x, c.weight, stride = stride, pad = pad, dilation = dilation) .+ b)
|
||||
σ.(NNlib._conv{pad, stride, dilation}()(x, c.weight) .+ b)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, l::Conv)
|
||||
|
Loading…
Reference in New Issue
Block a user