Compare commits

...

1 Commits

Author SHA1 Message Date
Dhairya Gandhi
4f988d5c20 initial step fn 2020-02-03 14:30:55 +05:30

View File

@ -42,6 +42,13 @@ function stop()
throw(StopException()) throw(StopException())
end end
function step!(loss, ps, minibatch, opt)
gs = gradient(ps) do
loss(minibatch...)
end
update!(opt, ps, gs)
end
""" """
train!(loss, params, data, opt; cb) train!(loss, params, data, opt; cb)
@ -65,10 +72,11 @@ function train!(loss, ps, data, opt; cb = () -> ())
cb = runall(cb) cb = runall(cb)
@progress for d in data @progress for d in data
try try
gs = gradient(ps) do # gs = gradient(ps) do
loss(d...) # loss(d...)
end # end
update!(opt, ps, gs) # update!(opt, ps, gs)
step!(loss, ps, d, opt)
cb() cb()
catch ex catch ex
if ex isa StopException if ex isa StopException