Merge pull request #517 from FluxML/fix_adamw
Fix decay argument in ADAMW
This commit is contained in:
commit
f6397e7358
@@ -228,7 +228,7 @@ end
|
|||||||
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
|
[ADAMW](https://arxiv.org/abs/1711.05101) fixing weight decay regularization in Adam.
|
||||||
"""
|
"""
|
||||||
ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) =
|
ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) =
|
||||||
Optimiser(ADAM(η, β), WeightDecay(wd))
|
Optimiser(ADAM(η, β), WeightDecay(decay))
|
||||||
|
|
||||||
# Compose optimizers
|
# Compose optimizers
|
||||||
|
|
||||||
|
@@ -4,7 +4,7 @@ using Flux.Tracker
|
|||||||
using Test
|
using Test
|
||||||
@testset "Optimise" begin
|
@testset "Optimise" begin
|
||||||
w = randn(10, 10)
|
w = randn(10, 10)
|
||||||
@testset for Opt in [ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
|
@testset for Opt in [ADAMW, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, Descent, ADAM, Nesterov, RMSProp, Momentum]
|
||||||
w′ = param(randn(10, 10))
|
w′ = param(randn(10, 10))
|
||||||
loss(x) = Flux.mse(w*x, w′*x)
|
loss(x) = Flux.mse(w*x, w′*x)
|
||||||
opt = Opt(0.001)
|
opt = Opt(0.001)
|
||||||
|
Loading…
Reference in New Issue
Block a user