added training api changes
This commit is contained in:
parent
1ea8c5a293
commit
d412845192
@ -24,9 +24,10 @@ m = Chain(
|
|||||||
Dense(32, 10), softmax)
|
Dense(32, 10), softmax)
|
||||||
|
|
||||||
loss(x, y) = Flux.mse(m(x), y)
|
loss(x, y) = Flux.mse(m(x), y)
|
||||||
|
ps = Flux.params(m)
|
||||||
|
|
||||||
# later
|
# later
|
||||||
Flux.train!(loss, params, data, opt)
|
Flux.train!(loss, ps, data, opt)
|
||||||
```
|
```
|
||||||
|
|
||||||
The objective will almost always be defined in terms of some *cost function* that measures the distance of the prediction `m(x)` from the target `y`. Flux has several of these built in, like `mse` for mean squared error or `crossentropy` for cross entropy loss, but you can calculate it however you want.
|
The objective will almost always be defined in terms of some *cost function* that measures the distance of the prediction `m(x)` from the target `y`. Flux has several of these built in, like `mse` for mean squared error or `crossentropy` for cross entropy loss, but you can calculate it however you want.
|
||||||
@ -78,7 +79,7 @@ julia> @epochs 2 Flux.train!(...)
|
|||||||
`train!` takes an additional argument, `cb`, that's used for callbacks so that you can observe the training process. For example:
|
`train!` takes an additional argument, `cb`, that's used for callbacks so that you can observe the training process. For example:
|
||||||
|
|
||||||
```julia
|
```julia
|
||||||
train!(objective, params, data, opt, cb = () -> println("training"))
|
train!(objective, ps, data, opt, cb = () -> println("training"))
|
||||||
```
|
```
|
||||||
|
|
||||||
Callbacks are called for every batch of training data. You can slow this down using `Flux.throttle(f, timeout)` which prevents `f` from being called more than once every `timeout` seconds.
|
Callbacks are called for every batch of training data. You can slow this down using `Flux.throttle(f, timeout)` which prevents `f` from being called more than once every `timeout` seconds.
|
||||||
@ -89,6 +90,6 @@ A more typical callback might look like this:
|
|||||||
test_x, test_y = # ... create single batch of test data ...
|
test_x, test_y = # ... create single batch of test data ...
|
||||||
evalcb() = @show(loss(test_x, test_y))
|
evalcb() = @show(loss(test_x, test_y))
|
||||||
|
|
||||||
Flux.train!(objective, data, opt,
|
Flux.train!(objective, ps, data, opt,
|
||||||
cb = throttle(evalcb, 5))
|
cb = throttle(evalcb, 5))
|
||||||
```
|
```
|
||||||
|
Loading…
Reference in New Issue
Block a user