tweaks
This commit is contained in:
parent
568869b9bf
commit
f2052739c1
|
@ -3,5 +3,5 @@
|
|||
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
|
||||
# back!(::typeof(mse), Δ, ŷ, y) = Δ .* (ŷ .- y)
|
||||
|
||||
logloss(ŷ, y) = -sum(y .* log.(ŷ))
|
||||
logloss(ŷ, y) = -sum(y .* log.(ŷ)) / size(y, 2)
|
||||
# back!(::typeof(logloss), Δ, ŷ, y) = 0 .- Δ .* y ./ ŷ
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
function descent(p::Param, η::Real)
|
||||
() -> p.x .-= p.Δ .* η
|
||||
function ()
|
||||
p.x .-= p.Δ .* η
|
||||
p.Δ .= 0
|
||||
end
|
||||
end
|
||||
|
||||
function momentum(p::Param, ρ::Real)
|
||||
|
|
|
@ -7,7 +7,10 @@ tocb(fs::AbstractVector) = () -> foreach(call, fs)
|
|||
function train!(m, data, opt; cb = () -> ())
|
||||
cb = tocb(cb)
|
||||
@progress for x in data
|
||||
back!(m(x...))
|
||||
l = m(x...)
|
||||
isinf(l.data[]) && error("Inf")
|
||||
isnan(l.data[]) && error("NaN")
|
||||
back!(l)
|
||||
opt()
|
||||
cb()
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue