gradient checks

This commit is contained in:
Mike J Innes 2017-08-23 01:43:45 +01:00
parent 416e0a3d4c
commit 5eee653a64
4 changed files with 38 additions and 0 deletions

View File

@ -77,5 +77,6 @@ function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = tru
end
include("lib.jl")
include("numeric.jl")
end

22
src/tracker/numeric.jl Normal file
View File

@ -0,0 +1,22 @@
function gradient(f, xs::AbstractArray...)
xs = track.(xs)
back!(f(xs...), [1])
grad.(xs)
end
function ngradient(f, xs::AbstractArray...)
y = f(xs...)
grads = zeros.(xs)
for (x, Δ) in zip(xs, grads)
for i in 1:length(x)
δ = sqrt(eps())
tmp, x[i] = x[i], x[i]+δ
y = f(xs...)
x[i] = tmp
Δ[i] = (y-y)/δ
end
end
return grads
end
gradcheck(f, xs...) = all(isapprox.(ngradient(f, xs...), gradient(f, xs...), rtol = 1e-6))

View File

@ -4,5 +4,6 @@ using Flux, Base.Test
include("compiler.jl")
include("utils.jl")
include("tracker.jl")
end

14
test/tracker.jl Normal file
View File

@ -0,0 +1,14 @@
using Flux.Tracker: gradcheck
using Base.Test, NNlib
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(f(xs...)), xs...)
gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@testset "Tracker" begin
@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))
end