fixes transpose/ adjoint gradient
This commit is contained in:
parent
4d79f499bf
commit
0060cc3453
@ -121,8 +121,8 @@ Base.:-(xs::TrackedArray) = track(-, xs)
|
||||
Base.transpose(xs::TrackedArray) = track(transpose, xs)
|
||||
Base.adjoint(xs::TrackedArray) = track(adjoint, xs)
|
||||
|
||||
@grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),)
|
||||
@grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)
|
||||
@grad transpose(xs) = transpose(data(xs)), Δ -> (trim(xs, transpose(Δ)),)
|
||||
@grad adjoint(xs) = data(xs)', Δ -> (trim(xs, Δ'),)
|
||||
|
||||
Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user