docs update

parent f5da4d0c70
commit af99ca27ee
@@ -18,6 +18,12 @@ git-tree-sha1 = "c88cfc7f9c1f9f8633cddf0b56e86302b70f64c5"
 uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 version = "1.0.1"
 
+[[ArrayLayouts]]
+deps = ["FillArrays", "LinearAlgebra"]
+git-tree-sha1 = "bc779df8d73be70e4e05a63727d3a4dfb4c52b1f"
+uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
+version = "0.1.5"
+
 [[Base64]]
 uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
 
@@ -230,9 +236,9 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804"
 
 [[NNlib]]
 deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Requires", "Statistics"]
-git-tree-sha1 = "21a3c22bc197b6ae2f8d4d75631876e2b6506dbe"
+git-tree-sha1 = "d9f196d911f55aeaff11b11f681b135980783824"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.6.5"
+version = "0.6.6"
 
 [[NaNMath]]
 git-tree-sha1 = "928b8ca9b2791081dc71a51c55347c27c618760f"
@@ -360,10 +366,10 @@ uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
 version = "1.2.11+8"
 
 [[Zygote]]
-deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
-git-tree-sha1 = "f8329b595c465caf3ca87c4f744e6041a4983e43"
+deps = ["ArrayLayouts", "DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
+git-tree-sha1 = "7dc5fdb4917ac5a84e199ae654316a01cd4a278b"
 uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.4.8"
+version = "0.4.9"
 
 [[ZygoteRules]]
 deps = ["MacroTools"]
@@ -12,9 +12,9 @@ NNlib.gelu
 NNlib.leakyrelu
 NNlib.logcosh
 NNlib.logsigmoid
-NNlib.sigmoid
 NNlib.relu
 NNlib.selu
+NNlib.sigmoid
 NNlib.softplus
 NNlib.softsign
 NNlib.swish

@@ -47,4 +47,5 @@ NNlib.depthwiseconv
 NNlib.batched_mul
+NNlib.batched_mul!
 NNlib.batched_adjoint
 NNlib.batched_transpose
 ```
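The activation list above includes `NNlib.logcosh`. As a rough illustration of why such functions need careful definitions, here is a numerically stable log-cosh written in plain Julia. This is a sketch, not NNlib's actual implementation (NNlib defines it differently, via `softplus`); it only shows the underlying identity `log(cosh(x)) = |x| + log1p(exp(-2|x|)) - log(2)`, which avoids the overflow of the naive form for large `|x|`:

```julia
# Naive log(cosh(x)) overflows: cosh(1000.0) == Inf, so log(cosh(1000.0)) == Inf.
# The algebraically equivalent form below stays finite, and oftype keeps the
# literal log(2) in the input's floating-point type (see the performance docs).
logcosh_stable(x) = abs(x) + log1p(exp(-2abs(x))) - log(oftype(x, 2))

logcosh_stable(3.0)     # matches log(cosh(3.0))
logcosh_stable(1000.0)  # finite, unlike log(cosh(1000.0))
```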
@@ -4,7 +4,7 @@ All the usual [Julia performance tips apply](https://docs.julialang.org/en/v1/ma
 As always [profiling your code](https://docs.julialang.org/en/v1/manual/profile/#Profiling-1) is generally a useful way of finding bottlenecks.
 Below follow some Flux specific tips/reminders.
 
-## Don't use more precision than you need.
+## Don't use more precision than you need
 
 Flux works great with all kinds of number types.
 But often you do not need to be working with say `Float64` (let alone `BigFloat`).

@@ -14,7 +14,8 @@ Which means allocations occur much faster.
 And you use less memory.
 
-## Make sure your activation and loss functions preserve the type of their inputs
+
+## Preserve inputs' types
 
 Not only should your activation and loss functions be [type-stable](https://docs.julialang.org/en/v1/manual/performance-tips/#Write-%22type-stable%22-functions-1),
 they should also preserve the type of their inputs.
 
@@ -29,24 +30,22 @@ because it results in having to use slow mixed type multiplication in the dense
 Similar situations can occur in the loss function during backpropagation.
 
 Which means if you change your data say from `Float64` to `Float32` (which should give a speedup: see above),
-you will see a large slow-down
+you will see a large slow-down.
 
 This can occur sneakily, because you can cause type-promotion by interacting with a numeric literals.
 E.g. the following will have run into the same problem as above:
 
 ```
-leaky_tanh(x) = 0.01x + tanh(x)
+leaky_tanh(x) = 0.01*x + tanh(x)
 ```
 
-While one could change your activation function (e.g. to use `0.01f0x`) to avoid this when ever your inputs change,
-the idiomatic (and safe way) is to use `oftype`.
-
+While one could change the activation function (e.g. to use `0.01f0x`), the idiomatic (and safe way) to avoid type casts whenever inputs changes is to use `oftype`:
 ```
-leaky_tanh(x) = oftype(x/1, 0.01)x + tanh(x)
+leaky_tanh(x) = oftype(x/1, 0.01)*x + tanh(x)
 ```
 
 
-## Evaluate batches as Matrices of features, rather than sequences of Vector features
+## Evaluate batches as Matrices of features
 
 While it can sometimes be tempting to process your observations (feature vectors) one at a time
 e.g.
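The type-promotion issue this hunk documents can be checked directly in plain Julia (no Flux needed). The two definitions below mirror the `leaky_tanh` variants from the diff: the `Float64` literal `0.01` promotes a `Float32` input, while `oftype(x/1, 0.01)` first converts the literal to the input's floating-point type:

```julia
# The literal 0.01 is a Float64, so 0.01*x promotes a Float32 input to Float64.
leaky_tanh_bad(x)  = 0.01*x + tanh(x)

# oftype(x/1, 0.01) converts the literal to x's floating-point type first,
# so the input type is preserved end to end.
leaky_tanh_good(x) = oftype(x/1, 0.01)*x + tanh(x)

x32 = 0.5f0
typeof(leaky_tanh_bad(x32))   # Float64 — input type not preserved
typeof(leaky_tanh_good(x32))  # Float32
```

`x/1` rather than `x` is used as the first argument so that integer inputs also map to a floating-point type.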
@@ -23,21 +23,25 @@ dimension.
 If `shuffle=true`, shuffles the observations each time iterations are re-started.
 If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
 
+The original data is preserved as a tuple in the `data` field of the DataLoader.
+
 Example usage:
 
     Xtrain = rand(10, 100)
-    dtrain = DataLoader(Xtrain, batchsize=2)
-    # iterate over 50 mini-batches
-    for x in dtrain:
+    train_loader = DataLoader(Xtrain, batchsize=2)
+    # iterate over 50 mini-batches of size 2
+    for x in train_loader:
         @assert size(x) == (10, 2)
         ...
     end
 
+    train_loader.data # original dataset
+
     Xtrain = rand(10, 100)
     Ytrain = rand(100)
-    dtrain = DataLoader(Xtrain, Ytrain, batchsize=2, shuffle=true)
+    train_loader = DataLoader(Xtrain, Ytrain, batchsize=2, shuffle=true)
     for epoch in 1:100
-        for (x, y) in dtrain:
+        for (x, y) in train_loader:
             @assert size(x) == (10, 2)
             @assert size(y) == (2,)
             ...

@@ -46,7 +50,7 @@ Example usage:
 
     # train for 10 epochs
     using IterTools: ncycle
-    Flux.train!(loss, ps, ncycle(dtrain, 10), opt)
+    Flux.train!(loss, ps, ncycle(train_loader, 10), opt)
 """
 function DataLoader(data...; batchsize=1, shuffle=false, partial=true)
     length(data) > 0 || throw(ArgumentError("Need at least one data input"))
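The batching behaviour documented in the docstring above can be sketched in Base Julia alone. This is not Flux's `DataLoader` implementation, only a minimal illustration (with a hypothetical helper name `minibatches`) of slicing a feature matrix along the last dimension, as the docstring's first example does:

```julia
# Minimal sketch of last-dimension mini-batching, using only Base Julia.
# Observations are stored column-wise, as in the DataLoader docstring example.
function minibatches(X::AbstractMatrix; batchsize=1)
    n = size(X, 2)
    # Iterators.partition splits 1:n into consecutive index chunks of the
    # requested size; each chunk selects a block of columns.
    (X[:, idx] for idx in Iterators.partition(1:n, batchsize))
end

Xtrain  = rand(10, 100)
batches = collect(minibatches(Xtrain; batchsize=2))
length(batches)        # 50 mini-batches
size(first(batches))   # (10, 2)
```

Unlike the real `DataLoader`, this sketch has no `shuffle`, no `partial` handling for a short final batch, and no multi-array support.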