change train
This commit is contained in:
parent
c1f0c29026
commit
f8c8bb4e35
|
@ -4,3 +4,4 @@
|
|||
docs/build/
|
||||
docs/site/
|
||||
deps
|
||||
Manifest.toml
|
||||
|
|
|
@ -56,14 +56,17 @@ function stop()
|
|||
throw(StopException())
|
||||
end
|
||||
|
||||
maketuple(x) = (x,)
|
||||
maketuple(x::Tuple) = x
|
||||
|
||||
"""
|
||||
train!(loss, params, data, opt; cb)
|
||||
|
||||
For each datapoint `d` in `data` compute the gradient of `loss(d...)` through
|
||||
backpropagation and call the optimizer `opt`.
|
||||
For each datapoint `d` in `data`, assumed to be a tuple, compute the gradient of `loss(d...)`
|
||||
with respect to `params`, and call the optimizer `opt`.
|
||||
|
||||
In case datapoints `d` are of numeric array type, assume no splatting is needed
|
||||
and compute the gradient of `loss(d)`.
|
||||
If `data` yields a tuple mini-batch `d` under iteration, it will be splatted in the function call
|
||||
`loss(d...)`, otherwise `loss(d)` will be called for non-tuple mini-batches.
|
||||
|
||||
A callback is given with the keyword argument `cb`. For example, this will print
|
||||
"training" every 10 seconds (using [`Flux.throttle`](@ref)):
|
||||
|
@ -80,14 +83,8 @@ function train!(loss, ps, data, opt; cb = () -> ())
|
|||
cb = runall(cb)
|
||||
@progress for d in data
|
||||
try
|
||||
if d isa AbstractArray{<:Number}
|
||||
gs = gradient(ps) do
|
||||
loss(d)
|
||||
end
|
||||
else
|
||||
gs = gradient(ps) do
|
||||
loss(d...)
|
||||
end
|
||||
gs = gradient(ps) do
|
||||
loss(maketuple(d)...)
|
||||
end
|
||||
update!(opt, ps, gs)
|
||||
cb()
|
||||
|
|
|
@ -2,7 +2,6 @@ using Flux
|
|||
using Flux.Data
|
||||
using Test
|
||||
using Random, Statistics, LinearAlgebra
|
||||
using Documenter
|
||||
using IterTools: ncycle
|
||||
|
||||
Random.seed!(0)
|
||||
|
@ -38,9 +37,10 @@ end
|
|||
end
|
||||
end
|
||||
|
||||
@testset "Docs" begin
|
||||
if VERSION >= v"1.4"
|
||||
@static if VERSION >= v"1.4"
|
||||
using Documenter
|
||||
@testset "Docs" begin
|
||||
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
|
||||
doctest(Flux)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue