Compare commits
1 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
4f988d5c20 |
@ -42,6 +42,13 @@ function stop()
|
|||||||
throw(StopException())
|
throw(StopException())
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function step!(loss, ps, minibatch, opt)
|
||||||
|
gs = gradient(ps) do
|
||||||
|
loss(minibatch...)
|
||||||
|
end
|
||||||
|
update!(opt, ps, gs)
|
||||||
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
train!(loss, params, data, opt; cb)
|
train!(loss, params, data, opt; cb)
|
||||||
|
|
||||||
@ -65,10 +72,11 @@ function train!(loss, ps, data, opt; cb = () -> ())
|
|||||||
cb = runall(cb)
|
cb = runall(cb)
|
||||||
@progress for d in data
|
@progress for d in data
|
||||||
try
|
try
|
||||||
gs = gradient(ps) do
|
# gs = gradient(ps) do
|
||||||
loss(d...)
|
# loss(d...)
|
||||||
end
|
# end
|
||||||
update!(opt, ps, gs)
|
# update!(opt, ps, gs)
|
||||||
|
step!(loss, ps, d, opt)
|
||||||
cb()
|
cb()
|
||||||
catch ex
|
catch ex
|
||||||
if ex isa StopException
|
if ex isa StopException
|
||||||
|
Loading…
Reference in New Issue
Block a user