From 97244e0a68fa8cbae17f8065160126897a674009 Mon Sep 17 00:00:00 2001 From: Fredrik Bagge Carlson Date: Sat, 4 Nov 2017 13:27:32 +0100 Subject: [PATCH] Allow array of optimisers to train! This allows an array of optimisers to be sent to `train!` --- src/optimise/train.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 2a2ec5eb..0809e86b 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -1,8 +1,8 @@ using Juno using Flux.Tracker: back! -tocb(f) = f -tocb(fs::AbstractVector) = () -> foreach(call, fs) +runall(f) = f +runall(fs::AbstractVector) = () -> foreach(call, fs) """ train!(loss, data, opt; cb = () -> ()) @@ -11,10 +11,11 @@ For each datapoint `d` in `data` computes the gradient of `loss(d...)` through backpropagation and calls the optimizer `opt` and the callback `cb` (i.e. `opt()` and `cb()`). -Multiple callbacks can be passed to `cb` as an array. +Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays. """ function train!(loss, data, opt; cb = () -> ()) - cb = tocb(cb) + cb = runall(cb) + opt = runall(opt) @progress for d in data l = loss(d...) isinf(l.data[]) && error("Loss is Inf")