diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 84bef9f6..994648c2 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -33,7 +33,8 @@ Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = Flux.treelike(Conv) function (c::Conv)(x) - ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1))) + # TODO: breaks gpu broadcast :( + # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1))) σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) σ.(conv(x, c.weight, stride = c.stride, pad = c.pad) .+ b) end