updated tests
This commit is contained in:
parent
4860c1d48b
commit
63bc71698b
|
@ -1,7 +1,7 @@
|
|||
module Optimise
|
||||
|
||||
export train!,
|
||||
Descent, ADAM, Momentum, Nesterov, RMSProp
|
||||
Descent, ADAM, Momentum, Nesterov, RMSProp, stop, StopException
|
||||
|
||||
include("optimisers.jl")
|
||||
include("train.jl")
|
||||
|
|
|
@ -102,6 +102,7 @@ function update!(o::ADAM, x, Δ)
|
|||
@. vt = β[2] * vt + (1 - β[2]) * Δ^2
|
||||
@. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η
|
||||
o.state[x] = (mt, vt, βp .* β)
|
||||
return Δ
|
||||
end
|
||||
|
||||
# """
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
using Juno
|
||||
using Flux.Tracker: data, grad, back!
|
||||
import Base.depwarn
|
||||
|
||||
function update!(opt, xs)
|
||||
for x in xs
|
||||
|
|
|
@ -74,6 +74,7 @@ include("numeric.jl")
|
|||
|
||||
"""
|
||||
hook(f, x) -> x′
|
||||
|
||||
Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
|
||||
`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
|
||||
the sign of the gradient applied to `x`."""
|
||||
|
@ -82,6 +83,7 @@ hook(f, x) = istracked(x) ? track(hook, f, x) : x
|
|||
|
||||
"""
|
||||
checkpoint(f, args...)
|
||||
|
||||
Behaves like `f(args...)`, but avoids storing the intermediate values needed for
|
||||
calculating gradients. Instead, `f(args...)` will be called again during the
|
||||
backward pass. This can be used to save memory in larger models.
|
||||
|
|
|
@ -3,18 +3,20 @@ using Flux.Tracker
|
|||
using Test
|
||||
@testset "Optimise" begin
|
||||
w = randn(10, 10)
|
||||
@testset for Opt in [Descent, Nesterov, RMSProp, ADAM, Momentum]
|
||||
w′ = param(randn(10, 10))
|
||||
delta = param(Tracker.similar(w′))
|
||||
loss(x) = Flux.mse(w*x, w′*x)
|
||||
@testset for Opt in [Descent, ADAM, Nesterov, RMSProp, Momentum]
|
||||
w′ = param(randn(10, 10))
|
||||
loss(x) = Flux.mse(w*x, w′*x)
|
||||
opt = Opt(0.001)
|
||||
if opt isa Descent
|
||||
opt = Opt(0.1)
|
||||
for t=1:10^5
|
||||
l = loss(rand(10))
|
||||
back!(l)
|
||||
update!(opt, w′.data, delta.data)
|
||||
w′ .-= delta
|
||||
end
|
||||
@test Flux.mse(w, w′) < 0.01
|
||||
end
|
||||
for t = 1: 10^5
|
||||
l = loss(rand(10))
|
||||
back!(l)
|
||||
delta = Optimise.update!(opt, w′.data, w′.grad)
|
||||
w′.data .-= delta
|
||||
end
|
||||
@test Flux.mse(w, w′) < 0.01
|
||||
end
|
||||
end
|
||||
|
||||
|
|
Loading…
Reference in New Issue