Merge pull request #124 from baggepinnen/jacobian
Add jacobian function
This commit is contained in:
commit
9c7c9d2342
19
src/utils.jl
19
src/utils.jl
@ -122,3 +122,22 @@ function throttle(f, timeout; leading=true, trailing=false)
|
|||||||
nothing
|
nothing
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
J = jacobian(m,x)
|
||||||
|
|
||||||
|
Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J` corresponds to the gradient `J[i,:] = ∇ₓ(m(x)[i])`
|
||||||
|
"""
|
||||||
|
function jacobian(m,x)
|
||||||
|
xp = param(x)
|
||||||
|
y = m(xp)
|
||||||
|
k = length(y)
|
||||||
|
n = length(x)
|
||||||
|
J = Matrix{eltype(x)}(n,k)
|
||||||
|
for i = 1:k
|
||||||
|
Flux.back!(y[i]) # Populate gradient accumulator
|
||||||
|
J[:,i] = xp.grad
|
||||||
|
xp.grad .*= 0 # Reset gradient accumulator
|
||||||
|
end
|
||||||
|
J'
|
||||||
|
end
|
||||||
|
@ -48,6 +48,15 @@ using Flux: throttle, initn, glorot_uniform, glorot_normal
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@testset "Jacobian" begin
|
||||||
|
A = param(randn(2,2))
|
||||||
|
x = randn(2)
|
||||||
|
m(x) = A*x
|
||||||
|
y = m(x)
|
||||||
|
J = jacobian(m,x)
|
||||||
|
@test J ≈ A.data
|
||||||
|
end
|
||||||
|
|
||||||
@testset "Initialization" begin
|
@testset "Initialization" begin
|
||||||
# Set random seed so that these tests don't fail randomly
|
# Set random seed so that these tests don't fail randomly
|
||||||
srand(0)
|
srand(0)
|
||||||
@ -69,4 +78,4 @@ end
|
|||||||
@test std(v) > 0.9*sqrt(2/(n_in + n_out))
|
@test std(v) > 0.9*sqrt(2/(n_in + n_out))
|
||||||
@test std(v) < 1.1*sqrt(2/(n_in + n_out))
|
@test std(v) < 1.1*sqrt(2/(n_in + n_out))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user