initial step fn

This commit is contained in:
Dhairya Gandhi 2020-02-03 14:30:55 +05:30
parent ddc2c20e68
commit 4f988d5c20

View File

@ -42,6 +42,13 @@ function stop()
throw(StopException()) throw(StopException())
end end
function step!(loss, ps, minibatch, opt)
gs = gradient(ps) do
loss(minibatch...)
end
update!(opt, ps, gs)
end
""" """
train!(loss, params, data, opt; cb) train!(loss, params, data, opt; cb)
@ -65,10 +72,11 @@ function train!(loss, ps, data, opt; cb = () -> ())
cb = runall(cb) cb = runall(cb)
@progress for d in data @progress for d in data
try try
gs = gradient(ps) do # gs = gradient(ps) do
loss(d...) # loss(d...)
end # end
update!(opt, ps, gs) # update!(opt, ps, gs)
step!(loss, ps, d, opt)
cb() cb()
catch ex catch ex
if ex isa StopException if ex isa StopException