gradient hook

This commit is contained in:
Mike J Innes 2018-07-02 13:17:46 +01:00
parent 5d8b63dc65
commit ce88273880
1 changed files with 10 additions and 0 deletions

View File

@ -58,6 +58,16 @@ include("scalar.jl")
include("array.jl")
include("numeric.jl")
"""
hook(f, x) -> x
Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
the sign of the gradient applied to `x`.
"""
hook(f, x) = istracked(x) ? track(hook, f, x) : x
back(::typeof(hook), Δ, f, x) = @back(x, f(Δ))
param(x::Number) = TrackedReal(float(x))
param(xs::AbstractArray) = TrackedArray(float.(xs))