change train

CarloLucibello 2020-05-04 14:49:17 +02:00
parent c1f0c29026
commit f8c8bb4e35
3 changed files with 14 additions and 16 deletions

.gitignore

@@ -4,3 +4,4 @@
 docs/build/
 docs/site/
 deps
+Manifest.toml

View File

@@ -56,14 +56,17 @@ function stop()
   throw(StopException())
 end
 
+maketuple(x) = (x,)
+maketuple(x::Tuple) = x
+
 """
     train!(loss, params, data, opt; cb)
 
-For each datapoint `d` in `data` compute the gradient of `loss(d...)` through
-backpropagation and call the optimizer `opt`.
+For each datapoint `d` in `data`, assumed to be a tuple, compute the gradient of `loss(d...)`
+with respect to `params`, and call the optimizer `opt`.
 
-In case datapoints `d` are of numeric array type, assume no splatting is needed
-and compute the gradient of `loss(d)`.
+If `data` yields a tuple mini-batch `d` under iteration, it will be splatted in the function call
+`loss(d...)`, otherwise `loss(d)` will be called for non-tuple mini-batches.
 
 A callback is given with the keyword argument `cb`. For example, this will print
 "training" every 10 seconds (using [`Flux.throttle`](@ref)):
@@ -80,14 +83,8 @@ function train!(loss, ps, data, opt; cb = () -> ())
   cb = runall(cb)
   @progress for d in data
     try
-      if d isa AbstractArray{<:Number}
-        gs = gradient(ps) do
-          loss(d)
-        end
-      else
-        gs = gradient(ps) do
-          loss(d...)
-        end
-      end
+      gs = gradient(ps) do
+        loss(maketuple(d)...)
+      end
       update!(opt, ps, gs)
       cb()
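
With the loop rewritten this way, `train!` no longer special-cases numeric arrays: a tuple mini-batch is splatted into `loss`, and any other element is passed whole as a single argument. A minimal usage sketch under the new behaviour (the model, losses, and data below are illustrative, not part of the commit):

    using Flux

    m = Dense(4, 2)
    ps = Flux.params(m)

    # Tuple mini-batches: each d = (x, y) is splatted, calling loss(x, y).
    loss(x, y) = Flux.mse(m(x), y)
    data = [(rand(Float32, 4, 8), rand(Float32, 2, 8)) for _ in 1:3]
    Flux.train!(loss, ps, data, Descent(0.1))

    # Non-tuple mini-batches: each d is passed whole, calling loss(d).
    loss1(x) = sum(m(x))
    xs = [rand(Float32, 4, 8) for _ in 1:3]
    Flux.train!(loss1, ps, xs, Descent(0.1))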

View File

@@ -2,7 +2,6 @@ using Flux
 using Flux.Data
 using Test
 using Random, Statistics, LinearAlgebra
-using Documenter
 using IterTools: ncycle
 
 Random.seed!(0)
@@ -38,9 +37,10 @@ end
   end
 end
 
-@testset "Docs" begin
-  if VERSION >= v"1.4"
+@static if VERSION >= v"1.4"
+  using Documenter
+  @testset "Docs" begin
     DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
     doctest(Flux)
   end
 end
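
In the test suite, `using Documenter` now lives inside the version gate, and `@static if` makes that gate a load-time decision: the condition is evaluated when the macro is expanded and the untaken branch is dropped from the lowered code, so older Julia versions never load Documenter or define the doctest testset. A small illustrative sketch of the same pattern (not from the commit):

    # `@static if` picks a branch at expansion time; only the chosen
    # branch makes it into the lowered code.
    @static if VERSION >= v"1.4"
        using Documenter   # loaded only where the doctests will actually run
    else
        @info "Skipping doctests on Julia $(VERSION)"
    end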