gradient hook
This commit is contained in:
parent
5d8b63dc65
commit
ce88273880
@ -58,6 +58,16 @@ include("scalar.jl")
|
|||||||
include("array.jl")
|
include("array.jl")
|
||||||
include("numeric.jl")
|
include("numeric.jl")
|
||||||
|
|
||||||
|
"""
|
||||||
|
hook(f, x) -> x′
|
||||||
|
|
||||||
|
Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
|
||||||
|
`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
|
||||||
|
the sign of the gradient applied to `x`.
|
||||||
|
"""
|
||||||
|
hook(f, x) = istracked(x) ? track(hook, f, x) : x
|
||||||
|
back(::typeof(hook), Δ, f, x) = @back(x, f(Δ))
|
||||||
|
|
||||||
param(x::Number) = TrackedReal(float(x))
|
param(x::Number) = TrackedReal(float(x))
|
||||||
param(xs::AbstractArray) = TrackedArray(float.(xs))
|
param(xs::AbstractArray) = TrackedArray(float.(xs))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user