updated tests
This commit is contained in:
parent 4860c1d48b
commit 63bc71698b
@@ -1,7 +1,7 @@
 module Optimise

 export train!,
-  Descent, ADAM, Momentum, Nesterov, RMSProp
+  Descent, ADAM, Momentum, Nesterov, RMSProp, stop, StopException

 include("optimisers.jl")
 include("train.jl")
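For context: `stop` and `StopException` are the early-stopping hooks newly exported alongside the optimisers. The usual pattern is that `stop()` throws a `StopException`, which the training loop catches to break out early. A minimal, self-contained sketch of that pattern (not the exact Flux source; `train_loop` and `step` are illustrative names):

    # Sketch of the stop/StopException early-stopping pattern.
    struct StopException <: Exception end

    stop() = throw(StopException())

    function train_loop(step, n)
      try
        for _ in 1:n
          step()                            # one training step; may call stop()
        end
      catch e
        e isa StopException || rethrow(e)   # swallow only the stop signal
      end
    end

    # Usage: halt after the 3rd step.
    counter = Ref(0)
    train_loop(() -> (counter[] += 1; counter[] == 3 && stop()), 10)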
@@ -102,6 +102,7 @@ function update!(o::ADAM, x, Δ)
   @. vt = β[2] * vt + (1 - β[2]) * Δ^2
   @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η
   o.state[x] = (mt, vt, βp .* β)
+  return Δ
 end

 # """
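With `return Δ`, `update!` now hands the computed step back to the caller, which applies it itself (the updated test below does `w′.data .-= delta`). A hedged, self-contained sketch of the same bias-corrected ADAM step, with `adam_step!` as an illustrative name:

    # One ADAM step: mt/vt are the running first/second moments, βp the current β powers.
    function adam_step!(mt, vt, βp, Δ; η = 0.001, β = (0.9, 0.999), ϵ = 1e-8)
      @. mt = β[1] * mt + (1 - β[1]) * Δ
      @. vt = β[2] * vt + (1 - β[2]) * Δ^2
      @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η
      return Δ                              # the caller applies it, e.g. x .-= Δ
    end

    # Usage sketch:
    x  = randn(10);  g = randn(10)
    mt = zero(x);    vt = zero(x)
    x .-= adam_step!(mt, vt, (0.9, 0.999), g)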
@@ -1,5 +1,6 @@
 using Juno
 using Flux.Tracker: data, grad, back!
+import Base.depwarn

 function update!(opt, xs)
   for x in xs
@@ -74,6 +74,7 @@ include("numeric.jl")

 """
     hook(f, x) -> x′

 Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
 `f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
 the sign of the gradient applied to `x`."""
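A hedged usage sketch for `hook` (assuming it is reachable as `Tracker.hook`; the variable names are illustrative):

    using Flux, Flux.Tracker

    x = param([1.0, 2.0, 3.0])
    y = sum(Tracker.hook(-, x))    # forward value is the same as sum(x)
    Tracker.back!(y)
    Tracker.grad(x)                # ≈ [-1.0, -1.0, -1.0]: the incoming gradient was negated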
@@ -82,6 +83,7 @@ hook(f, x) = istracked(x) ? track(hook, f, x) : x

 """
     checkpoint(f, args...)

 Behaves like `f(args...)`, but avoids storing the intermediate values needed for
 calculating gradients. Instead, `f(args...)` will be called again during the
 backward pass. This can be used to save memory in larger models.
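And a corresponding hedged sketch for `checkpoint` (again assuming `Tracker.checkpoint` is reachable; `W` and `x` are illustrative):

    using Flux, Flux.Tracker

    W = param(randn(5, 5))
    x = rand(5)
    y = sum(Tracker.checkpoint(*, W, x))   # W*x is recomputed, not stored, for the backward pass
    Tracker.back!(y)
    Tracker.grad(W)                        # same gradient as sum(W*x) would give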
@@ -3,18 +3,20 @@ using Flux.Tracker
 using Test
 @testset "Optimise" begin
   w = randn(10, 10)
-  @testset for Opt in [Descent, Nesterov, RMSProp, ADAM, Momentum]
+  @testset for Opt in [Descent, ADAM, Nesterov, RMSProp, Momentum]
     w′ = param(randn(10, 10))
-    delta = param(Tracker.similar(w′))
     loss(x) = Flux.mse(w*x, w′*x)
-    opt = Opt(0.1)
-    for t=1:10^5
+    opt = Opt(0.001)
+    if opt isa Descent
+      opt = Opt(0.1)
+    end
+    for t = 1:10^5
       l = loss(rand(10))
       back!(l)
-      update!(opt, w′.data, delta.data)
-      w′ .-= delta
+      delta = Optimise.update!(opt, w′.data, w′.grad)
+      w′.data .-= delta
     end
     @test Flux.mse(w, w′) < 0.01
   end
 end
