diff --git a/src/layers/params.jl b/src/layers/params.jl index 6c78fa69..a8e42f93 100644 --- a/src/layers/params.jl +++ b/src/layers/params.jl @@ -6,7 +6,6 @@ end param(x) = Param(x, zero(x)) state(p::Param) = p.x -state(x) = x function accumulate!(p::Param, Δ) p.Δx .+= Δ @@ -14,8 +13,9 @@ function accumulate!(p::Param, Δ) end function update!(p::Param, η) - p.x .+= p.Δx * η + p.x .+= p.Δx .* η return p end +state(x) = x accumulate!(x, Δ) = x