gradient hook
This commit is contained in:
parent
5d8b63dc65
commit
ce88273880
|
@ -58,6 +58,16 @@ include("scalar.jl")
|
|||
include("array.jl")
|
||||
include("numeric.jl")
|
||||
|
||||
"""
|
||||
hook(f, x) -> x′
|
||||
|
||||
Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
|
||||
`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
|
||||
the sign of the gradient applied to `x`.
|
||||
"""
|
||||
hook(f, x) = istracked(x) ? track(hook, f, x) : x
|
||||
back(::typeof(hook), Δ, f, x) = @back(x, f(Δ))
|
||||
|
||||
param(x::Number) = TrackedReal(float(x))
|
||||
param(xs::AbstractArray) = TrackedArray(float.(xs))
|
||||
|
||||
|
|
Loading…
Reference in New Issue