diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 63146f5f..2d805af9 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -87,7 +87,7 @@ Hook into gradient backpropagation. `x` is unmodified, but when backpropagating the sign of the gradient applied to `x`. """ hook(f, x) = istracked(x) ? track(hook, f, x) : x -@grad hook(f, x) = x, Δ -> (nothing, f(Δ)) +@grad hook(f, x) = data(x), Δ -> (nothing, f(Δ)) """ checkpoint(f, args...)