add hessian
This commit is contained in:
parent
8386a49bf9
commit
9e553adbf7
|
@ -6,7 +6,7 @@ using MacroTools: @q, @forward
|
|||
import Base: ==
|
||||
|
||||
export TrackedArray, TrackedVector, TrackedMatrix, Params, gradient,
|
||||
param, back!
|
||||
jacobian, hessian, param, back!
|
||||
|
||||
tracker(x) = nothing
|
||||
|
||||
|
|
|
@ -181,3 +181,28 @@ gradient(f, xs...; nest = false) =
|
|||
nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
|
||||
|
||||
gradient(f, ps::Params) = gradient_nested(f, ps)
|
||||
|
||||
# Jacobians and Hessians
|
||||
|
||||
import ..Flux
|
||||
|
||||
"""
|
||||
J = jacobian(m,x)
|
||||
|
||||
Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J` corresponds to the gradient `J[i,:] = ∇ₓ(m(x)[i])`
|
||||
"""
|
||||
function jacobian(m,x)
|
||||
xp = param(x)
|
||||
y = m(xp)
|
||||
k = length(y)
|
||||
n = length(x)
|
||||
J = Matrix{eltype(x)}(undef,k,n)
|
||||
for i = 1:k
|
||||
Flux.back!(y[i], once = false) # Populate gradient accumulator
|
||||
J[i,:] = xp.grad
|
||||
xp.grad .= 0 # Reset gradient accumulator
|
||||
end
|
||||
J
|
||||
end
|
||||
|
||||
hessian(f, x) = jacobian(x -> gradient(f, x, nest=true)[1], x)
|
||||
|
|
19
src/utils.jl
19
src/utils.jl
|
@ -139,25 +139,6 @@ function throttle(f, timeout; leading=true, trailing=false)
|
|||
end
|
||||
end
|
||||
|
||||
"""
|
||||
J = jacobian(m,x)
|
||||
|
||||
Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J` corresponds to the gradient `J[i,:] = ∇ₓ(m(x)[i])`
|
||||
"""
|
||||
function jacobian(m,x)
|
||||
xp = param(x)
|
||||
y = m(xp)
|
||||
k = length(y)
|
||||
n = length(x)
|
||||
J = Matrix{eltype(x)}(undef,n,k)
|
||||
for i = 1:k
|
||||
Flux.back!(y[i], once = false) # Populate gradient accumulator
|
||||
J[:,i] = xp.grad
|
||||
xp.grad .= 0 # Reset gradient accumulator
|
||||
end
|
||||
J'
|
||||
end
|
||||
|
||||
"""
|
||||
@jit ...
|
||||
|
||||
|
|
Loading…
Reference in New Issue