Merge pull request #124 from baggepinnen/jacobian

Add jacobian function
This commit is contained in:
Mike J Innes 2017-12-13 17:07:57 +00:00 committed by GitHub
commit 9c7c9d2342
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 1 deletions

View File

@ -122,3 +122,22 @@ function throttle(f, timeout; leading=true, trailing=false)
nothing nothing
end end
end end
"""
J = jacobian(m,x)
Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J` corresponds to the gradient `J[i,:] = ∇ₓ(m(x)[i])`
"""
function jacobian(m,x)
xp = param(x)
y = m(xp)
k = length(y)
n = length(x)
J = Matrix{eltype(x)}(n,k)
for i = 1:k
Flux.back!(y[i]) # Populate gradient accumulator
J[:,i] = xp.grad
xp.grad .*= 0 # Reset gradient accumulator
end
J'
end

View File

@ -48,6 +48,15 @@ using Flux: throttle, initn, glorot_uniform, glorot_normal
end end
end end
@testset "Jacobian" begin
A = param(randn(2,2))
x = randn(2)
m(x) = A*x
y = m(x)
J = jacobian(m,x)
@test J A.data
end
@testset "Initialization" begin @testset "Initialization" begin
# Set random seed so that these tests don't fail randomly # Set random seed so that these tests don't fail randomly
srand(0) srand(0)
@ -69,4 +78,4 @@ end
@test std(v) > 0.9*sqrt(2/(n_in + n_out)) @test std(v) > 0.9*sqrt(2/(n_in + n_out))
@test std(v) < 1.1*sqrt(2/(n_in + n_out)) @test std(v) < 1.1*sqrt(2/(n_in + n_out))
end end
end end