gradient checks
This commit is contained in:
parent
416e0a3d4c
commit
5eee653a64
@ -77,5 +77,6 @@ function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = tru
|
||||
end
|
||||
|
||||
include("lib.jl")
|
||||
include("numeric.jl")
|
||||
|
||||
end
|
||||
|
22
src/tracker/numeric.jl
Normal file
22
src/tracker/numeric.jl
Normal file
@ -0,0 +1,22 @@
|
||||
function gradient(f, xs::AbstractArray...)
|
||||
xs = track.(xs)
|
||||
back!(f(xs...), [1])
|
||||
grad.(xs)
|
||||
end
|
||||
|
||||
function ngradient(f, xs::AbstractArray...)
|
||||
y = f(xs...)
|
||||
grads = zeros.(xs)
|
||||
for (x, Δ) in zip(xs, grads)
|
||||
for i in 1:length(x)
|
||||
δ = sqrt(eps())
|
||||
tmp, x[i] = x[i], x[i]+δ
|
||||
y′ = f(xs...)
|
||||
x[i] = tmp
|
||||
Δ[i] = (y′-y)/δ
|
||||
end
|
||||
end
|
||||
return grads
|
||||
end
|
||||
|
||||
gradcheck(f, xs...) = all(isapprox.(ngradient(f, xs...), gradient(f, xs...), rtol = 1e-6))
|
@ -4,5 +4,6 @@ using Flux, Base.Test
|
||||
|
||||
include("compiler.jl")
|
||||
include("utils.jl")
|
||||
include("tracker.jl")
|
||||
|
||||
end
|
||||
|
14
test/tracker.jl
Normal file
14
test/tracker.jl
Normal file
@ -0,0 +1,14 @@
|
||||
using Flux.Tracker: gradcheck
|
||||
using Base.Test, NNlib
|
||||
|
||||
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(f(xs...)), xs...)
|
||||
gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
||||
|
||||
@testset "Tracker" begin
|
||||
|
||||
@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
|
||||
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
|
||||
|
||||
@test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))
|
||||
|
||||
end
|
Loading…
Reference in New Issue
Block a user