diff --git a/src/tracker/numeric.jl b/src/tracker/numeric.jl index 73c63029..65cab8b8 100644 --- a/src/tracker/numeric.jl +++ b/src/tracker/numeric.jl @@ -5,15 +5,17 @@ function gradient(f, xs::AbstractArray...) end function ngradient(f, xs::AbstractArray...) - y = f(xs...) grads = zeros.(xs) for (x, Δ) in zip(xs, grads) for i in 1:length(x) δ = sqrt(eps()) - tmp, x[i] = x[i], x[i]+δ - y′ = f(xs...) + tmp = x[i] + x[i] = tmp - δ/2 + y1 = f(xs...) + x[i] = tmp + δ/2 + y2 = f(xs...) x[i] = tmp - Δ[i] = (y′-y)/δ + Δ[i] = (y2-y1)/δ end end return grads diff --git a/test/tracker.jl b/test/tracker.jl index 1e64b23a..258e1af4 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -11,10 +11,12 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5)) -gradtest(x -> softmax(x).*(1:3), 3) -gradtest(x -> softmax(x).*(1:3), (3,5)) +@test gradtest(x -> softmax(x).*(1:3), 3) +@test gradtest(x -> softmax(x).*(1:3), (3,5)) -gradtest(Flux.mse, rand(5,5), rand(5, 5)) -gradtest(Flux.logloss, rand(5,5), rand(5, 5)) +@test gradtest(Flux.mse, rand(5,5), rand(5, 5)) +@test gradtest(Flux.logloss, rand(5,5), rand(5, 5)) + +@test gradtest(x -> x', rand(5)) end