diff --git a/src/tracker/array.jl b/src/tracker/array.jl index fe8a5fc4..28941c11 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -358,7 +358,7 @@ conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...) @grad conv(x, w; kw...) = conv(data(x), data(w); kw...), Δ -> nobacksies(:conv, - (NNlib.∇conv_data(data.((Δ, w))..., size(x); kw...), + (NNlib.∇conv_data(data.((Δ, w))...; size=size(x), kw...), NNlib.∇conv_filter(data.((Δ, x, w))...; kw...))) ∇conv_data(x::TrackedArray, w::TrackedArray; kw...) = track(∇conv_data, x, w; kw...) @@ -368,7 +368,7 @@ conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...) @grad ∇conv_data(x, w; kw...) = ∇conv_data(data(x), data(w); kw...), Δ -> nobacksies(:conv, - (NNlib.conv(data.((Δ, w))..., size(x); kw...), + (NNlib.conv(data.((Δ, w))...; size=size(x), kw...), NNlib.∇conv_filter(data.((x, Δ, w))...; kw...))) maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)