From 0060cc345374565545d63c074429a03654cab749 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Tue, 15 Jan 2019 21:59:32 +0530 Subject: [PATCH] fixes transpose/ adjoint gradient --- src/tracker/lib/array.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tracker/lib/array.jl b/src/tracker/lib/array.jl index 08a40db7..01d9bd23 100644 --- a/src/tracker/lib/array.jl +++ b/src/tracker/lib/array.jl @@ -121,8 +121,8 @@ Base.:-(xs::TrackedArray) = track(-, xs) Base.transpose(xs::TrackedArray) = track(transpose, xs) Base.adjoint(xs::TrackedArray) = track(adjoint, xs) -@grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),) -@grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),) +@grad transpose(xs) = transpose(data(xs)), Δ -> (trim(xs, transpose(Δ)),) +@grad adjoint(xs) = data(xs)', Δ -> (trim(xs, Δ'),) Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)