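# Flux.jl/src/tracker/numeric.jl — Julia, 25 lines, 513 B
# (viewer links: Raw | Normal View | History)
#
# Finite-difference utilities for checking the tracker's gradients.
# Blame timestamp: 2017-08-23 00:43:45 +00:00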
"""
    gradient(f, xs::AbstractArray...)

Compute the gradient of `f` with respect to each array in `xs` via the tracker.

Every argument is wrapped with `track`, `back!` is run on the result of
`f` applied to the tracked arguments, and the `grad` of each tracked argument is
returned as a tuple in the same order as `xs`.
NOTE(review): `f` is presumably expected to return a scalar tracked value — confirm.
"""
function gradient(f, xs::AbstractArray...)
    tracked = map(track, xs)
    back!(f(tracked...))
    return map(grad, tracked)
end
"""
    ngradient(f, xs::AbstractArray...)

Estimate the gradient of the scalar-valued `f(xs...)` with respect to every
element of every array in `xs` using central finite differences.

Each element `x[i]` is temporarily set to `x[i] - δ/2` and `x[i] + δ/2`, and `f`
is re-evaluated at both points. The slope `(y2 - y1) / δ` is stored in the
matching slot of the result. The step is `δ = sqrt(eps(T))`, where `T` is the
element's real floating-point type. For `Float64` this equals `sqrt(eps())`.
For lower-precision types such as `Float32`, the perturbation stays
representable instead of rounding back to `x[i]`.

The arrays in `xs` are mutated during evaluation but each entry is restored
afterwards, even if `f` throws.

Returns a tuple with one array per input, each shaped like (`zero(x)`) the
corresponding input.
"""
function ngradient(f, xs::AbstractArray...)
    # `zero(x)` keeps each input's axes and element type; `zeros(x::AbstractArray)`
    # no longer exists since Julia 0.7.
    grads = zero.(xs)
    for (x, Δ) in zip(xs, grads)
        for i in eachindex(x, Δ)
            tmp = x[i]
            # Step sized to the element's own precision (== sqrt(eps()) for Float64).
            δ = sqrt(eps(float(real(typeof(tmp)))))
            try
                x[i] = tmp - δ/2
                y1 = f(xs...)
                x[i] = tmp + δ/2
                y2 = f(xs...)
                Δ[i] = (y2 - y1) / δ
            finally
                # Always undo the perturbation so the caller's arrays are unchanged.
                x[i] = tmp
            end
        end
    end
    return grads
end
"""
    gradcheck(f, xs...; rtol = 1e-6, atol = 0)

Compare the tracker gradient `gradient(f, xs...)` against the finite-difference
estimate `ngradient(f, xs...)`. Return `true` when every per-argument pair
satisfies `isapprox` with the given tolerances.

`rtol` defaults to `1e-6`, the value previously hard-coded here. `atol` defaults
to `0`, which is also `isapprox`'s default, so existing callers see identical
behaviour. Raise `atol` when the true gradient is all (or nearly all) zero,
because a purely relative tolerance cannot pass there.
"""
function gradcheck(f, xs...; rtol = 1e-6, atol = 0)
    # Same evaluation order as before: numeric estimate first, then the tracker.
    numeric = ngradient(f, xs...)
    analytic = gradient(f, xs...)
    return all(isapprox.(numeric, analytic; rtol = rtol, atol = atol))
end