fixes transpose/ adjoint gradient

This commit is contained in:
Dhairya Gandhi 2019-01-15 21:59:32 +05:30
parent 4d79f499bf
commit 0060cc3453

View File

@ -121,8 +121,8 @@ Base.:-(xs::TrackedArray) = track(-, xs)
Base.transpose(xs::TrackedArray) = track(transpose, xs)
Base.adjoint(xs::TrackedArray) = track(adjoint, xs)
@grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),)
@grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)
@grad transpose(xs) = transpose(data(xs)), Δ -> (trim(xs, transpose(Δ)),)
@grad adjoint(xs) = data(xs)', Δ -> (trim(xs, Δ'),)
Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)