change train

This commit is contained in:
CarloLucibello 2020-05-04 14:49:17 +02:00
parent c1f0c29026
commit f8c8bb4e35
3 changed files with 14 additions and 16 deletions

1
.gitignore vendored
View File

@ -4,3 +4,4 @@
docs/build/ docs/build/
docs/site/ docs/site/
deps deps
Manifest.toml

View File

@ -56,14 +56,17 @@ function stop()
throw(StopException()) throw(StopException())
end end
maketuple(x) = (x,)
maketuple(x::Tuple) = x
""" """
train!(loss, params, data, opt; cb) train!(loss, params, data, opt; cb)
For each datapoint `d` in `data` compute the gradient of `loss(d...)` through For each datapoint `d` in `data`, assumed to be a tuple, compute the gradient of `loss(d...)`
backpropagation and call the optimizer `opt`. with respect to `params`, and call the optimizer `opt`.
In case datapoints `d` are of numeric array type, assume no splatting is needed If `data` yields a tuple mini-batch `d` under iteration, it will be splatted in the function call
and compute the gradient of `loss(d)`. `loss(d...)`, otherwise `loss(d)` will be called for non-tuple mini-batches.
A callback is given with the keyword argument `cb`. For example, this will print A callback is given with the keyword argument `cb`. For example, this will print
"training" every 10 seconds (using [`Flux.throttle`](@ref)): "training" every 10 seconds (using [`Flux.throttle`](@ref)):
@ -80,14 +83,8 @@ function train!(loss, ps, data, opt; cb = () -> ())
cb = runall(cb) cb = runall(cb)
@progress for d in data @progress for d in data
try try
if d isa AbstractArray{<:Number} gs = gradient(ps) do
gs = gradient(ps) do loss(maketuple(d)...)
loss(d)
end
else
gs = gradient(ps) do
loss(d...)
end
end end
update!(opt, ps, gs) update!(opt, ps, gs)
cb() cb()

View File

@ -2,7 +2,6 @@ using Flux
using Flux.Data using Flux.Data
using Test using Test
using Random, Statistics, LinearAlgebra using Random, Statistics, LinearAlgebra
using Documenter
using IterTools: ncycle using IterTools: ncycle
Random.seed!(0) Random.seed!(0)
@ -38,9 +37,10 @@ end
end end
end end
@testset "Docs" begin @static if VERSION >= v"1.4"
if VERSION >= v"1.4" using Documenter
@testset "Docs" begin
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true) DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
doctest(Flux) doctest(Flux)
end end
end end