Compare commits
1 Commits
Author | SHA1 | Date |
---|---|---|
![]() |
4f988d5c20 |
|
@ -42,6 +42,13 @@ function stop()
|
|||
throw(StopException())
|
||||
end
|
||||
|
||||
function step!(loss, ps, minibatch, opt)
|
||||
gs = gradient(ps) do
|
||||
loss(minibatch...)
|
||||
end
|
||||
update!(opt, ps, gs)
|
||||
end
|
||||
|
||||
"""
|
||||
train!(loss, params, data, opt; cb)
|
||||
|
||||
|
@ -65,10 +72,11 @@ function train!(loss, ps, data, opt; cb = () -> ())
|
|||
cb = runall(cb)
|
||||
@progress for d in data
|
||||
try
|
||||
gs = gradient(ps) do
|
||||
loss(d...)
|
||||
end
|
||||
update!(opt, ps, gs)
|
||||
# gs = gradient(ps) do
|
||||
# loss(d...)
|
||||
# end
|
||||
# update!(opt, ps, gs)
|
||||
step!(loss, ps, d, opt)
|
||||
cb()
|
||||
catch ex
|
||||
if ex isa StopException
|
||||
|
|
Loading…
Reference in New Issue