From 0dc4ec4d6beb8f080e89badafc4f39690e40b2d2 Mon Sep 17 00:00:00 2001 From: Tejan Karmali Date: Wed, 24 Oct 2018 07:04:49 -0400 Subject: [PATCH] conv_data grad api change --- src/tracker/array.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 9353e7ff..fe8a5fc4 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -358,7 +358,7 @@ conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...) @grad conv(x, w; kw...) = conv(data(x), data(w); kw...), Δ -> nobacksies(:conv, - (NNlib.∇conv_data(data.((Δ, x, w))...; kw...), + (NNlib.∇conv_data(data.((Δ, w))..., size(x); kw...), NNlib.∇conv_filter(data.((Δ, x, w))...; kw...))) ∇conv_data(x::TrackedArray, w::TrackedArray; kw...) = track(∇conv_data, x, w; kw...) @@ -368,7 +368,7 @@ conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...) @grad ∇conv_data(x, w; kw...) = ∇conv_data(data(x), data(w); kw...), Δ -> nobacksies(:conv, - (NNlib.conv(data.((x, Δ, w))...; kw...), + (NNlib.conv(data.((Δ, w))..., size(x); kw...), NNlib.∇conv_filter(data.((x, Δ, w))...; kw...))) maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)