updated tests

Dhairya Gandhi 2018-09-14 20:32:56 +05:30
parent 4860c1d48b
commit 63bc71698b
5 changed files with 18 additions and 12 deletions

View File

@@ -1,7 +1,7 @@
 module Optimise

 export train!,
-  Descent, ADAM, Momentum, Nesterov, RMSProp
+  Descent, ADAM, Momentum, Nesterov, RMSProp, stop, StopException

 include("optimisers.jl")
 include("train.jl")

View File

@@ -102,6 +102,7 @@ function update!(o::ADAM, x, Δ)
   @. vt = β[2] * vt + (1 - β[2]) * Δ^2
   @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η
   o.state[x] = (mt, vt, βp .* β)
+  return Δ
 end

 # """

View File

@@ -1,5 +1,6 @@
 using Juno
 using Flux.Tracker: data, grad, back!
+import Base.depwarn

 function update!(opt, xs)
   for x in xs

View File

@@ -74,6 +74,7 @@ include("numeric.jl")
 """
     hook(f, x) -> x

 Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
 `f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
 the sign of the gradient applied to `x`."""
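
A small hedged usage sketch of hook as documented above (Tracker API of this period; the values are only illustrative):

    using Flux
    using Flux.Tracker: hook, back!, grad

    x = param([1.0, 2.0, 3.0])
    y = hook(-, x)    # same values as x, but gradients flowing back through y are negated
    back!(sum(y))     # the gradient of sum with respect to y is all ones
    grad(x)           # [-1.0, -1.0, -1.0]: sign reversed by the hook, as documented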
@@ -82,6 +83,7 @@ hook(f, x) = istracked(x) ? track(hook, f, x) : x
 """
     checkpoint(f, args...)

 Behaves like `f(args...)`, but avoids storing the intermediate values needed for
 calculating gradients. Instead, `f(args...)` will be called again during the
 backward pass. This can be used to save memory in larger models.
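
And a hedged sketch of checkpoint as described (the closure is a made-up toy function; the gradient should match the un-checkpointed computation, with the forward pass re-run during back!):

    using Flux
    using Flux.Tracker: checkpoint, back!, grad

    W = param(randn(3, 3))
    x = rand(3)

    y = checkpoint(w -> sum(tanh.(w * x)), W)  # forward value; intermediates are not stored
    back!(y)                                   # the closure runs again here to rebuild the gradient
    grad(W)                                    # same gradient as sum(tanh.(W * x)) would give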

View File

@@ -3,16 +3,18 @@ using Flux.Tracker
 using Test
 @testset "Optimise" begin
   w = randn(10, 10)
-  @testset for Opt in [Descent, Nesterov, RMSProp, ADAM, Momentum]
+  @testset for Opt in [Descent, ADAM, Nesterov, RMSProp, Momentum]
     w′ = param(randn(10, 10))
-    delta = param(Tracker.similar(w))
     loss(x) = Flux.mse(w*x, w′*x)
-    opt = Opt(0.1)
-    for t=1:10^5
+    opt = Opt(0.001)
+    if opt isa Descent
+      opt = Opt(0.1)
+    end
+    for t = 1: 10^5
       l = loss(rand(10))
       back!(l)
-      update!(opt, w′.data, delta.data)
-      w′ .-= delta
+      delta = Optimise.update!(opt, w′.data, w′.grad)
+      w′.data .-= delta
     end
     @test Flux.mse(w, w′) < 0.01
   end