better numeric grads

This commit is contained in:
Mike J Innes 2017-09-03 17:10:35 -04:00
parent 8f4ccdd5ba
commit 788d7d35f0
2 changed files with 12 additions and 8 deletions

View File

@ -5,15 +5,17 @@ function gradient(f, xs::AbstractArray...)
end end
function ngradient(f, xs::AbstractArray...) function ngradient(f, xs::AbstractArray...)
y = f(xs...)
grads = zeros.(xs) grads = zeros.(xs)
for (x, Δ) in zip(xs, grads) for (x, Δ) in zip(xs, grads)
for i in 1:length(x) for i in 1:length(x)
δ = sqrt(eps()) δ = sqrt(eps())
tmp, x[i] = x[i], x[i]+δ tmp = x[i]
y = f(xs...) x[i] = tmp - δ/2
y1 = f(xs...)
x[i] = tmp + δ/2
y2 = f(xs...)
x[i] = tmp x[i] = tmp
Δ[i] = (y-y)/δ Δ[i] = (y2-y1)/δ
end end
end end
return grads return grads

View File

@ -11,10 +11,12 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5)) @test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))
gradtest(x -> softmax(x).*(1:3), 3) @test gradtest(x -> softmax(x).*(1:3), 3)
gradtest(x -> softmax(x).*(1:3), (3,5)) @test gradtest(x -> softmax(x).*(1:3), (3,5))
gradtest(Flux.mse, rand(5,5), rand(5, 5)) @test gradtest(Flux.mse, rand(5,5), rand(5, 5))
gradtest(Flux.logloss, rand(5,5), rand(5, 5)) @test gradtest(Flux.logloss, rand(5,5), rand(5, 5))
@test gradtest(x -> x', rand(5))
end end