fix signs
This commit is contained in:
parent
8224c77f7d
commit
e17d1cbe7a
@ -70,7 +70,7 @@ end
|
|||||||
function Flux.update!(model::MXModel, η)
|
function Flux.update!(model::MXModel, η)
|
||||||
for (arg, grad) in zip(model.exec.arg_arrays, model.exec.grad_arrays)
|
for (arg, grad) in zip(model.exec.arg_arrays, model.exec.grad_arrays)
|
||||||
mx.@nd_as_jl rw = (arg, grad) begin
|
mx.@nd_as_jl rw = (arg, grad) begin
|
||||||
arg .+= grad .* η
|
arg .-= grad .* η
|
||||||
grad[:] = 0
|
grad[:] = 0
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -24,7 +24,7 @@ function accumulate!(p::Param, Δ)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function update!(p::Param, η)
|
function update!(p::Param, η)
|
||||||
p.x .+= p.Δx .* η
|
p.x .-= p.Δx .* η
|
||||||
p.Δx[:] = 0
|
p.Δx[:] = 0
|
||||||
return p
|
return p
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user