Merge pull request #124 from baggepinnen/jacobian
Add jacobian function
This commit is contained in:
commit
9c7c9d2342
19
src/utils.jl
19
src/utils.jl
@ -122,3 +122,22 @@ function throttle(f, timeout; leading=true, trailing=false)
|
||||
nothing
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
J = jacobian(m,x)
|
||||
|
||||
Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J` corresponds to the gradient `J[i,:] = ∇ₓ(m(x)[i])`
|
||||
"""
|
||||
function jacobian(m,x)
|
||||
xp = param(x)
|
||||
y = m(xp)
|
||||
k = length(y)
|
||||
n = length(x)
|
||||
J = Matrix{eltype(x)}(n,k)
|
||||
for i = 1:k
|
||||
Flux.back!(y[i]) # Populate gradient accumulator
|
||||
J[:,i] = xp.grad
|
||||
xp.grad .*= 0 # Reset gradient accumulator
|
||||
end
|
||||
J'
|
||||
end
|
||||
|
@ -48,6 +48,15 @@ using Flux: throttle, initn, glorot_uniform, glorot_normal
|
||||
end
|
||||
end
|
||||
|
||||
@testset "Jacobian" begin
|
||||
A = param(randn(2,2))
|
||||
x = randn(2)
|
||||
m(x) = A*x
|
||||
y = m(x)
|
||||
J = jacobian(m,x)
|
||||
@test J ≈ A.data
|
||||
end
|
||||
|
||||
@testset "Initialization" begin
|
||||
# Set random seed so that these tests don't fail randomly
|
||||
srand(0)
|
||||
@ -69,4 +78,4 @@ end
|
||||
@test std(v) > 0.9*sqrt(2/(n_in + n_out))
|
||||
@test std(v) < 1.1*sqrt(2/(n_in + n_out))
|
||||
end
|
||||
end
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user