better numeric grads
This commit is contained in:
parent
8f4ccdd5ba
commit
788d7d35f0
@ -5,15 +5,17 @@ function gradient(f, xs::AbstractArray...)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function ngradient(f, xs::AbstractArray...)
|
function ngradient(f, xs::AbstractArray...)
|
||||||
y = f(xs...)
|
|
||||||
grads = zeros.(xs)
|
grads = zeros.(xs)
|
||||||
for (x, Δ) in zip(xs, grads)
|
for (x, Δ) in zip(xs, grads)
|
||||||
for i in 1:length(x)
|
for i in 1:length(x)
|
||||||
δ = sqrt(eps())
|
δ = sqrt(eps())
|
||||||
tmp, x[i] = x[i], x[i]+δ
|
tmp = x[i]
|
||||||
y′ = f(xs...)
|
x[i] = tmp - δ/2
|
||||||
|
y1 = f(xs...)
|
||||||
|
x[i] = tmp + δ/2
|
||||||
|
y2 = f(xs...)
|
||||||
x[i] = tmp
|
x[i] = tmp
|
||||||
Δ[i] = (y′-y)/δ
|
Δ[i] = (y2-y1)/δ
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
return grads
|
return grads
|
||||||
|
@ -11,10 +11,12 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||||||
|
|
||||||
@test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))
|
@test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))
|
||||||
|
|
||||||
gradtest(x -> softmax(x).*(1:3), 3)
|
@test gradtest(x -> softmax(x).*(1:3), 3)
|
||||||
gradtest(x -> softmax(x).*(1:3), (3,5))
|
@test gradtest(x -> softmax(x).*(1:3), (3,5))
|
||||||
|
|
||||||
gradtest(Flux.mse, rand(5,5), rand(5, 5))
|
@test gradtest(Flux.mse, rand(5,5), rand(5, 5))
|
||||||
gradtest(Flux.logloss, rand(5,5), rand(5, 5))
|
@test gradtest(Flux.logloss, rand(5,5), rand(5, 5))
|
||||||
|
|
||||||
|
@test gradtest(x -> x', rand(5))
|
||||||
|
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user