docs update

commit af99ca27ee
parent f5da4d0c70
@@ -18,6 +18,12 @@ git-tree-sha1 = "c88cfc7f9c1f9f8633cddf0b56e86302b70f64c5"
 uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 version = "1.0.1"
 
+[[ArrayLayouts]]
+deps = ["FillArrays", "LinearAlgebra"]
+git-tree-sha1 = "bc779df8d73be70e4e05a63727d3a4dfb4c52b1f"
+uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
+version = "0.1.5"
+
 [[Base64]]
 uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
 
@@ -230,9 +236,9 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804"
 
 [[NNlib]]
 deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Requires", "Statistics"]
-git-tree-sha1 = "21a3c22bc197b6ae2f8d4d75631876e2b6506dbe"
+git-tree-sha1 = "d9f196d911f55aeaff11b11f681b135980783824"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.6.5"
+version = "0.6.6"
 
 [[NaNMath]]
 git-tree-sha1 = "928b8ca9b2791081dc71a51c55347c27c618760f"
@@ -360,10 +366,10 @@ uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
 version = "1.2.11+8"
 
 [[Zygote]]
-deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
-git-tree-sha1 = "f8329b595c465caf3ca87c4f744e6041a4983e43"
+deps = ["ArrayLayouts", "DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
+git-tree-sha1 = "7dc5fdb4917ac5a84e199ae654316a01cd4a278b"
 uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.4.8"
+version = "0.4.9"
 
 [[ZygoteRules]]
 deps = ["MacroTools"]
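
The Manifest hunks above bump NNlib 0.6.5 → 0.6.6 and Zygote 0.4.8 → 0.4.9, and pull in ArrayLayouts as a new dependency of Zygote. As a hedged sketch of how such a Manifest.toml update is typically generated (the exact resolver output depends on the registry state at the time, so the hashes here are not reproducible from this snippet):

```julia
# Hypothetical reproduction of this Manifest.toml change.
using Pkg
Pkg.activate(".")                 # activate the Flux project environment
Pkg.update(["NNlib", "Zygote"])   # rewrites the version/git-tree-sha1 entries
Pkg.status()                      # ArrayLayouts now appears via Zygote's deps
```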
@@ -12,9 +12,9 @@ NNlib.gelu
 NNlib.leakyrelu
 NNlib.logcosh
 NNlib.logsigmoid
-NNlib.sigmoid
 NNlib.relu
 NNlib.selu
+NNlib.sigmoid
 NNlib.softplus
 NNlib.softsign
 NNlib.swish
@@ -47,4 +47,5 @@ NNlib.depthwiseconv
 NNlib.batched_mul
 NNlib.batched_mul!
 NNlib.batched_adjoint
+NNlib.batched_transpose
 ```
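
Besides re-sorting `NNlib.sigmoid` alphabetically, the hunk above adds `NNlib.batched_transpose` to the documented batched operations. A hedged usage sketch of these functions (array sizes chosen purely for illustration):

```julia
using NNlib

A = rand(Float32, 2, 3, 4)   # a batch of four 2×3 matrices
B = rand(Float32, 3, 5, 4)   # a batch of four 3×5 matrices

C  = batched_mul(A, B)       # four 2×5 products; size(C) == (2, 5, 4)
At = batched_transpose(A)    # lazy per-matrix transpose; size(At) == (3, 2, 4)
Aa = batched_adjoint(A)      # per-matrix adjoint (conjugate transpose)
```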
@@ -4,7 +4,7 @@ All the usual [Julia performance tips apply](https://docs.julialang.org/en/v1/ma
 As always [profiling your code](https://docs.julialang.org/en/v1/manual/profile/#Profiling-1) is generally a useful way of finding bottlenecks.
 Below follow some Flux specific tips/reminders.
 
-## Don't use more precision than you need.
+## Don't use more precision than you need
 
 Flux works great with all kinds of number types.
 But often you do not need to be working with say `Float64` (let alone `BigFloat`).
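
As a hedged illustration of the precision advice (not part of this commit): converting data to `Float32` halves its memory footprint relative to Julia's `Float64` default.

```julia
x64 = rand(100)            # Float64 by default
x32 = Float32.(x64)        # broadcast conversion to single precision
sizeof(x64), sizeof(x32)   # (800, 400): 8 vs 4 bytes per element
```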
@@ -14,7 +14,8 @@ Which means allocations occur much faster.
 And you use less memory.
 
 
-## Make sure your activation and loss functions preserve the type of their inputs
+## Preserve inputs' types
+
 Not only should your activation and loss functions be [type-stable](https://docs.julialang.org/en/v1/manual/performance-tips/#Write-%22type-stable%22-functions-1),
 they should also preserve the type of their inputs.
 
@@ -29,24 +30,22 @@ because it results in having to use slow mixed type multiplication in the dense
 Similar situations can occur in the loss function during backpropagation.
 
 Which means if you change your data say from `Float64` to `Float32` (which should give a speedup: see above),
-you will see a large slow-down
+you will see a large slow-down.
 
 This can occur sneakily, because you can cause type-promotion by interacting with numeric literals.
 E.g. the following will run into the same problem as above:
 
 ```
-leaky_tanh(x) = 0.01x + tanh(x)
+leaky_tanh(x) = 0.01*x + tanh(x)
 ```
 
-While one could change your activation function (e.g. to use `0.01f0x`) to avoid this when ever your inputs change,
-the idiomatic (and safe way) is to use `oftype`.
-
+While one could change the activation function (e.g. to use `0.01f0x`), the idiomatic (and safe) way to avoid type casts whenever inputs change is to use `oftype`:
 ```
-leaky_tanh(x) = oftype(x/1, 0.01)x + tanh(x)
+leaky_tanh(x) = oftype(x/1, 0.01)*x + tanh(x)
 ```
 
 
-## Evaluate batches as Matrices of features, rather than sequences of Vector features
+## Evaluate batches as Matrices of features
 
 While it can sometimes be tempting to process your observations (feature vectors) one at a time
 e.g.
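
A hedged check of the two `leaky_tanh` variants from the hunk above: the `Float64` literal `0.01` promotes `Float32` inputs, while the `oftype` form preserves the input's type.

```julia
leaky_tanh_promoting(x) = 0.01*x + tanh(x)               # literal is Float64
leaky_tanh_stable(x)    = oftype(x/1, 0.01)*x + tanh(x)  # literal matches x

x = 0.5f0                          # a Float32 input
typeof(leaky_tanh_promoting(x))    # Float64: silently promoted
typeof(leaky_tanh_stable(x))       # Float32: input type preserved
```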
@@ -23,21 +23,25 @@ dimension.
 If `shuffle=true`, shuffles the observations each time iterations are re-started.
 If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
 
+The original data is preserved as a tuple in the `data` field of the DataLoader.
+
 Example usage:
 
     Xtrain = rand(10, 100)
-    dtrain = DataLoader(Xtrain, batchsize=2)
-    # iterate over 50 mini-batches
-    for x in dtrain:
+    train_loader = DataLoader(Xtrain, batchsize=2)
+    # iterate over 50 mini-batches of size 2
+    for x in train_loader
        @assert size(x) == (10, 2)
        ...
    end
 
+    train_loader.data # original dataset
+
    Xtrain = rand(10, 100)
    Ytrain = rand(100)
-    dtrain = DataLoader(Xtrain, Ytrain, batchsize=2, shuffle=true)
+    train_loader = DataLoader(Xtrain, Ytrain, batchsize=2, shuffle=true)
    for epoch in 1:100
-        for (x, y) in dtrain:
+        for (x, y) in train_loader
           @assert size(x) == (10, 2)
          @assert size(y) == (2,)
          ...
@@ -46,7 +50,7 @@ Example usage:
 
     # train for 10 epochs
     using IterTools: ncycle
-    Flux.train!(loss, ps, ncycle(dtrain, 10), opt)
+    Flux.train!(loss, ps, ncycle(train_loader, 10), opt)
 """
 function DataLoader(data...; batchsize=1, shuffle=false, partial=true)
     length(data) > 0 || throw(ArgumentError("Need at least one data input"))
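
A hedged sketch of the `partial` keyword documented above, assuming `DataLoader` is reachable as `Flux.Data.DataLoader` in this version of Flux and supports `length`:

```julia
using Flux.Data: DataLoader   # module path assumed for this Flux version

X = rand(10, 100)             # 100 observations along the last dimension
length(DataLoader(X, batchsize=3, partial=true))   # 34: last batch has 1 obs
length(DataLoader(X, batchsize=3, partial=false))  # 33: short batch dropped
```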