updated tests

Dhairya Gandhi 2018-09-14 20:32:56 +05:30
parent 4860c1d48b
commit 63bc71698b
5 changed files with 18 additions and 12 deletions


@@ -1,7 +1,7 @@
 module Optimise
 
 export train!,
-  Descent, ADAM, Momentum, Nesterov, RMSProp
+  Descent, ADAM, Momentum, Nesterov, RMSProp, stop, StopException
 
 include("optimisers.jl")
 include("train.jl")


@@ -102,6 +102,7 @@ function update!(o::ADAM, x, Δ)
   @. vt = β[2] * vt + (1 - β[2]) * Δ^2
   @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η
   o.state[x] = (mt, vt, βp .* β)
+  return Δ
 end
 
 # """


@@ -1,5 +1,6 @@
 using Juno
 using Flux.Tracker: data, grad, back!
+import Base.depwarn
 
 function update!(opt, xs)
   for x in xs
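
Only the first lines of the train-level `update!(opt, xs)` are visible in this hunk. Presumably it walks the tracked parameters and applies the step returned by the per-optimiser `update!` above; a hedged guess at its overall shape (the body here is an assumption, not the file's actual code):

# Assumed shape only: loop over the tracked params, get a step from the
# optimiser, and apply it, using the data/grad helpers imported above.
function update!(opt, xs)
  for x in xs
    Δ = update!(opt, data(x), grad(x))  # per-parameter step from the optimiser
    data(x) .-= Δ                       # apply it in place
  end
end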


@@ -74,6 +74,7 @@ include("numeric.jl")
 
 """
     hook(f, x) -> x
 Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
 `f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
 the sign of the gradient applied to `x`."""
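
A small usage sketch for `hook` (hypothetical numbers; `param`, `back!` and `grad` as used elsewhere in this diff):

using Flux
using Flux.Tracker: hook, back!, grad

x = param(2.0)
y = hook(-, x)   # same value as x; only the backward pass is affected
back!(3y)        # d(3y)/dy == 3 flows back to y…
grad(x)          # == -3.0, because hook applied `-` to the incoming gradient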
@@ -82,6 +83,7 @@ hook(f, x) = istracked(x) ? track(hook, f, x) : x
 
 """
     checkpoint(f, args...)
 Behaves like `f(args...)`, but avoids storing the intermediate values needed for
 calculating gradients. Instead, `f(args...)` will be called again during the
 backward pass. This can be used to save memory in larger models.
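
And a sketch of `checkpoint` with a throwaway `f` (hypothetical), just to show the recompute-on-backward behaviour described above:

using Flux
using Flux.Tracker: checkpoint, back!, grad

W = param(randn(3, 3))
x = param(randn(3))
f(W, x) = sum(tanh.(W * x))

y = checkpoint(f, W, x)   # same value as f(W, x), but intermediates are not kept
back!(y)                  # f runs again here to rebuild them for the backward pass
grad(W)                   # gradients match what plain f(W, x) would give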


@@ -3,18 +3,20 @@ using Flux.Tracker
 using Test
 
 @testset "Optimise" begin
   w = randn(10, 10)
-  @testset for Opt in [Descent, Nesterov, RMSProp, ADAM, Momentum]
-    w′ = param(randn(10, 10))
-    delta = param(Tracker.similar(w))
-    loss(x) = Flux.mse(w*x, w′*x)
+  @testset for Opt in [Descent, ADAM, Nesterov, RMSProp, Momentum]
+    w′ = param(randn(10, 10))
+    loss(x) = Flux.mse(w*x, w′*x)
+    opt = Opt(0.001)
+    if opt isa Descent
     opt = Opt(0.1)
-    for t=1:10^5
-      l = loss(rand(10))
-      back!(l)
-      update!(opt, w′.data, delta.data)
-      w′ .-= delta
-    end
-    @test Flux.mse(w, w′) < 0.01
+    end
+    for t = 1: 10^5
+      l = loss(rand(10))
+      back!(l)
+      delta = Optimise.update!(opt, w′.data, w′.grad)
+      w′.data .-= delta
+    end
+    @test Flux.mse(w, w′) < 0.01
   end
 end