change train
This commit is contained in:
parent
c1f0c29026
commit
f8c8bb4e35
|
@ -4,3 +4,4 @@
|
||||||
docs/build/
|
docs/build/
|
||||||
docs/site/
|
docs/site/
|
||||||
deps
|
deps
|
||||||
|
Manifest.toml
|
||||||
|
|
|
@ -56,14 +56,17 @@ function stop()
|
||||||
throw(StopException())
|
throw(StopException())
|
||||||
end
|
end
|
||||||
|
|
||||||
|
maketuple(x) = (x,)
|
||||||
|
maketuple(x::Tuple) = x
|
||||||
|
|
||||||
"""
|
"""
|
||||||
train!(loss, params, data, opt; cb)
|
train!(loss, params, data, opt; cb)
|
||||||
|
|
||||||
For each datapoint `d` in `data` compute the gradient of `loss(d...)` through
|
For each datapoint `d` in `data`, assumed to be a tuple, compute the gradient of `loss(d...)`
|
||||||
backpropagation and call the optimizer `opt`.
|
with respect to `params`, and call the optimizer `opt`.
|
||||||
|
|
||||||
In case datapoints `d` are of numeric array type, assume no splatting is needed
|
If `data` yields a tuple mini-batch `d` under iteration, it will be splatted in the function call
|
||||||
and compute the gradient of `loss(d)`.
|
`loss(d...)`, otherwise `loss(d)` will be called for non-tuple mini-batches.
|
||||||
|
|
||||||
A callback is given with the keyword argument `cb`. For example, this will print
|
A callback is given with the keyword argument `cb`. For example, this will print
|
||||||
"training" every 10 seconds (using [`Flux.throttle`](@ref)):
|
"training" every 10 seconds (using [`Flux.throttle`](@ref)):
|
||||||
|
@ -80,14 +83,8 @@ function train!(loss, ps, data, opt; cb = () -> ())
|
||||||
cb = runall(cb)
|
cb = runall(cb)
|
||||||
@progress for d in data
|
@progress for d in data
|
||||||
try
|
try
|
||||||
if d isa AbstractArray{<:Number}
|
gs = gradient(ps) do
|
||||||
gs = gradient(ps) do
|
loss(maketuple(d)...)
|
||||||
loss(d)
|
|
||||||
end
|
|
||||||
else
|
|
||||||
gs = gradient(ps) do
|
|
||||||
loss(d...)
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
update!(opt, ps, gs)
|
update!(opt, ps, gs)
|
||||||
cb()
|
cb()
|
||||||
|
|
|
@ -2,7 +2,6 @@ using Flux
|
||||||
using Flux.Data
|
using Flux.Data
|
||||||
using Test
|
using Test
|
||||||
using Random, Statistics, LinearAlgebra
|
using Random, Statistics, LinearAlgebra
|
||||||
using Documenter
|
|
||||||
using IterTools: ncycle
|
using IterTools: ncycle
|
||||||
|
|
||||||
Random.seed!(0)
|
Random.seed!(0)
|
||||||
|
@ -38,9 +37,10 @@ end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@testset "Docs" begin
|
@static if VERSION >= v"1.4"
|
||||||
if VERSION >= v"1.4"
|
using Documenter
|
||||||
|
@testset "Docs" begin
|
||||||
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
|
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
|
||||||
doctest(Flux)
|
doctest(Flux)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in New Issue