diff --git a/src/tracker/lib/array.jl b/src/tracker/lib/array.jl index e32efebf..beb7453e 100644 --- a/src/tracker/lib/array.jl +++ b/src/tracker/lib/array.jl @@ -383,7 +383,7 @@ conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...) @grad conv(x, w; kw...) = conv(data(x), data(w); kw...), Δ -> nobacksies(:conv, - (NNlib.∇conv_data(data.((Δ, w))..., size(x); kw...), + (NNlib.∇conv_data(data.((Δ, w))...; size=size(x), kw...), NNlib.∇conv_filter(data.((Δ, x, w))...; kw...))) ∇conv_data(x::TrackedArray, w::TrackedArray; kw...) = track(∇conv_data, x, w; kw...) @@ -393,7 +393,7 @@ conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...) @grad ∇conv_data(x, w; kw...) = ∇conv_data(data(x), data(w); kw...), Δ -> nobacksies(:conv, - (NNlib.conv(data.((Δ, w))..., size(x); kw...), + (NNlib.conv(data.((Δ, w))...; size=size(x), kw...), NNlib.∇conv_filter(data.((x, Δ, w))...; kw...))) maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)