diff --git a/src/utils.jl b/src/utils.jl index 944d35bf..9a03ae4f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -122,3 +122,22 @@ function throttle(f, timeout; leading=true, trailing=false) nothing end end + +""" + J = jacobian(m,x) + +Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J` corresponds to the gradient `J[i,:] = ∇ₓ(m(x)[i])` +""" +function jacobian(m,x) + xp = param(x) + y = m(xp) + k = length(y) + n = length(x) + J = Matrix{eltype(x)}(n,k) + for i = 1:k + Flux.back!(y[i]) # Populate gradient accumulator + J[:,i] = xp.grad + xp.grad .*= 0 # Reset gradient accumulator + end + J' +end diff --git a/test/utils.jl b/test/utils.jl index 1c313a3d..34762adf 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -48,6 +48,15 @@ using Flux: throttle, initn, glorot_uniform, glorot_normal end end +@testset "Jacobian" begin + A = param(randn(2,2)) + x = randn(2) + m(x) = A*x + y = m(x) + J = jacobian(m,x) + @test J ≈ A.data +end + @testset "Initialization" begin # Set random seed so that these tests don't fail randomly srand(0) @@ -69,4 +78,4 @@ end @test std(v) > 0.9*sqrt(2/(n_in + n_out)) @test std(v) < 1.1*sqrt(2/(n_in + n_out)) end -end \ No newline at end of file +end