From e51070bf799b90b40e690a0f5dc4ab728cac76bb Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Mon, 2 Mar 2020 10:52:27 +0100 Subject: [PATCH 1/4] update documenter --- Manifest.toml | 10 ++--- Project.toml | 1 - docs/Manifest.toml | 89 ---------------------------------------- docs/Project.toml | 3 ++ docs/make.jl | 10 ++--- docs/src/models/nnlib.md | 13 ++++++ 6 files changed, 23 insertions(+), 103 deletions(-) delete mode 100644 docs/Manifest.toml diff --git a/Manifest.toml b/Manifest.toml index 788e5354..dac05aec 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -230,9 +230,9 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[NNlib]] deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Requires", "Statistics"] -git-tree-sha1 = "755c0bab3912ff782167e1b4b774b833f8a0e550" +git-tree-sha1 = "21a3c22bc197b6ae2f8d4d75631876e2b6506dbe" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.6.4" +version = "0.6.5" [[NaNMath]] git-tree-sha1 = "928b8ca9b2791081dc71a51c55347c27c618760f" @@ -361,11 +361,9 @@ version = "1.2.11+8" [[Zygote]] deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "3c65158c0aa0808cdfff8bca2a36430b038aad00" -repo-rev = "master" -repo-url = "https://github.com/FluxML/Zygote.jl.git" +git-tree-sha1 = "f8329b595c465caf3ca87c4f744e6041a4983e43" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.4.7" +version = "0.4.8" [[ZygoteRules]] deps = ["MacroTools"] diff --git a/Project.toml b/Project.toml index bd105730..a27d766b 100644 --- a/Project.toml +++ b/Project.toml @@ -44,6 +44,5 @@ IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - [targets] test = ["Test", "Documenter", "IterTools", "LinearAlgebra"] diff --git a/docs/Manifest.toml b/docs/Manifest.toml deleted file mode 100644 index bf9d220a..00000000 --- a/docs/Manifest.toml +++ /dev/null @@ -1,89 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -[[Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - -[[DocStringExtensions]] -deps = ["LibGit2", "Markdown", "Pkg", "Test"] -git-tree-sha1 = "0513f1a8991e9d83255e0140aace0d0fc4486600" -uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.8.0" - -[[Documenter]] -deps = ["Base64", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] -git-tree-sha1 = "c61d6eedbc3c4323c08b64af12d29c8ee0fcbb5f" -uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "0.23.2" - -[[InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[JSON]] -deps = ["Dates", "Mmap", "Parsers", "Unicode"] -git-tree-sha1 = "b34d7cef7b337321e97d22242c3c2b91f476748e" -uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" -version = "0.21.0" - -[[LibGit2]] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[Parsers]] -deps = ["Dates", "Test"] -git-tree-sha1 = "db2b35dedab3c0e46dc15996d170af07a5ab91c9" -uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "0.3.6" - -[[Pkg]] -deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" - -[[Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[Random]] -deps = ["Serialization"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" - -[[Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[Test]] -deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[UUIDs]] -deps = ["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" diff --git a/docs/Project.toml b/docs/Project.toml index dfa65cd1..1b9ab1f8 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,2 +1,5 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" + +[compat] +Documenter = "0.24" diff --git a/docs/make.jl b/docs/make.jl index e42d8217..03fbf413 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,9 +1,3 @@ -using Pkg; -Pkg.activate(joinpath(@__DIR__, "..")); Pkg.instantiate() -Pkg.activate(); Pkg.instantiate() - -pushfirst!(LOAD_PATH, joinpath(@__DIR__, "..")) - using Documenter, Flux, NNlib makedocs(modules=[Flux, NNlib], @@ -30,4 +24,6 @@ makedocs(modules=[Flux, NNlib], analytics = "UA-36890222-9", prettyurls = haskey(ENV, "CI"))) -deploydocs(repo = "github.com/FluxML/Flux.jl.git") +deploydocs(repo = "github.com/FluxML/Flux.jl.git", + target = "build", + push_preview = true) diff --git a/docs/src/models/nnlib.md b/docs/src/models/nnlib.md index 9e570cb3..698a95ae 100644 --- a/docs/src/models/nnlib.md +++ b/docs/src/models/nnlib.md @@ -1,7 +1,9 @@ # NNlib + Flux re-exports all of the functions exported by the [NNlib](https://github.com/FluxML/NNlib.jl) package. ## Activation Functions + Non-linearities that go between layers of your model. Note that, unless otherwise stated, activation functions operate on scalars. To apply them to an array you can call `σ.(xs)`, `relu.(xs)` and so on. ```@docs @@ -19,19 +21,30 @@ NNlib.swish ``` ## Softmax + ```@docs NNlib.softmax NNlib.logsoftmax ``` ## Pooling + ```@docs NNlib.maxpool NNlib.meanpool ``` ## Convolution + ```@docs NNlib.conv NNlib.depthwiseconv +``` + +## Batched Operations + +```@docs +NNlib.batched_mul +NNlib.batched_mul! +NNlib.batched_adjoint ``` \ No newline at end of file From ffea8b616dcf0576e09a5f3ec61f0b277570c4b9 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Mon, 2 Mar 2020 15:07:50 +0100 Subject: [PATCH 2/4] fix docs --- docs/Manifest.toml | 122 +++++++++++++++++++++++++++++++++++++++++++++ docs/Project.toml | 1 + 2 files changed, 123 insertions(+) create mode 100644 docs/Manifest.toml diff --git a/docs/Manifest.toml b/docs/Manifest.toml new file mode 100644 index 00000000..82d46743 --- /dev/null +++ b/docs/Manifest.toml @@ -0,0 +1,122 @@ +# This file is machine-generated - editing it directly is not advised + +[[Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[BinaryProvider]] +deps = ["Libdl", "SHA"] +git-tree-sha1 = "5b08ed6036d9d3f0ee6369410b830f8873d4024c" +uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" +version = "0.5.8" + +[[Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[DocStringExtensions]] +deps = ["LibGit2", "Markdown", "Pkg", "Test"] +git-tree-sha1 = "88bb0edb352b16608036faadcc071adda068582a" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.8.1" + +[[Documenter]] +deps = ["Base64", "Dates", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] +git-tree-sha1 = "d497bcc45bb98a1fbe19445a774cfafeabc6c6df" +uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +version = "0.24.5" + +[[InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[JSON]] +deps = ["Dates", "Mmap", "Parsers", "Unicode"] +git-tree-sha1 = "b34d7cef7b337321e97d22242c3c2b91f476748e" +uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +version = "0.21.0" + +[[LibGit2]] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[LinearAlgebra]] +deps = ["Libdl"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[NNlib]] +deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Requires", "Statistics"] +git-tree-sha1 = "21a3c22bc197b6ae2f8d4d75631876e2b6506dbe" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.6.5" + +[[Parsers]] +deps = ["Dates", "Test"] +git-tree-sha1 = "0c16b3179190d3046c073440d94172cfc3bb0553" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "0.3.12" + +[[Pkg]] +deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Test", "UUIDs"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" + +[[Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[Random]] +deps = ["Serialization"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "d37400976e98018ee840e0ca4f9d20baa231dc6b" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.0.1" + +[[SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" + +[[Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[SparseArrays]] +deps = ["LinearAlgebra", "Random"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[Test]] +deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" diff --git a/docs/Project.toml b/docs/Project.toml index 1b9ab1f8..670a65be 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,5 +1,6 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" [compat] Documenter = "0.24" From f5da4d0c70140f573ded4233073c336389cf212a Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Mon, 2 Mar 2020 15:10:08 +0100 Subject: [PATCH 3/4] remove docs manifest --- docs/Manifest.toml | 122 --------------------------------------------- 1 file changed, 122 deletions(-) delete mode 100644 docs/Manifest.toml diff --git a/docs/Manifest.toml b/docs/Manifest.toml deleted file mode 100644 index 82d46743..00000000 --- a/docs/Manifest.toml +++ /dev/null @@ -1,122 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -[[Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[BinaryProvider]] -deps = ["Libdl", "SHA"] -git-tree-sha1 = "5b08ed6036d9d3f0ee6369410b830f8873d4024c" -uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" -version = "0.5.8" - -[[Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - -[[DocStringExtensions]] -deps = ["LibGit2", "Markdown", "Pkg", "Test"] -git-tree-sha1 = "88bb0edb352b16608036faadcc071adda068582a" -uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.8.1" - -[[Documenter]] -deps = ["Base64", "Dates", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] -git-tree-sha1 = "d497bcc45bb98a1fbe19445a774cfafeabc6c6df" -uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "0.24.5" - -[[InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[JSON]] -deps = ["Dates", "Mmap", "Parsers", "Unicode"] -git-tree-sha1 = "b34d7cef7b337321e97d22242c3c2b91f476748e" -uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" -version = "0.21.0" - -[[LibGit2]] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[LinearAlgebra]] -deps = ["Libdl"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[NNlib]] -deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Requires", "Statistics"] -git-tree-sha1 = "21a3c22bc197b6ae2f8d4d75631876e2b6506dbe" -uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.6.5" - -[[Parsers]] -deps = ["Dates", "Test"] -git-tree-sha1 = "0c16b3179190d3046c073440d94172cfc3bb0553" -uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "0.3.12" - -[[Pkg]] -deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Test", "UUIDs"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" - -[[Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[Random]] -deps = ["Serialization"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[Requires]] -deps = ["UUIDs"] -git-tree-sha1 = "d37400976e98018ee840e0ca4f9d20baa231dc6b" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.0.1" - -[[SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" - -[[Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[SparseArrays]] -deps = ["LinearAlgebra", "Random"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - -[[Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" - -[[Test]] -deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[UUIDs]] -deps = ["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" From af99ca27eee01101ebeb06425c6b8fab495b7b2c Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Tue, 3 Mar 2020 07:52:20 +0100 Subject: [PATCH 4/4] docs update --- Manifest.toml | 16 +++++++++++----- docs/src/models/nnlib.md | 3 ++- docs/src/performance.md | 17 ++++++++--------- src/data/dataloader.jl | 16 ++++++++++------ 4 files changed, 31 insertions(+), 21 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index dac05aec..04465cae 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -18,6 +18,12 @@ git-tree-sha1 = "c88cfc7f9c1f9f8633cddf0b56e86302b70f64c5" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" version = "1.0.1" +[[ArrayLayouts]] +deps = ["FillArrays", "LinearAlgebra"] +git-tree-sha1 = "bc779df8d73be70e4e05a63727d3a4dfb4c52b1f" +uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" +version = "0.1.5" + [[Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" @@ -230,9 +236,9 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[NNlib]] deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Requires", "Statistics"] -git-tree-sha1 = "21a3c22bc197b6ae2f8d4d75631876e2b6506dbe" +git-tree-sha1 = "d9f196d911f55aeaff11b11f681b135980783824" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.6.5" +version = "0.6.6" [[NaNMath]] git-tree-sha1 = "928b8ca9b2791081dc71a51c55347c27c618760f" @@ -360,10 +366,10 @@ uuid = "83775a58-1f1d-513f-b197-d71354ab007a" version = "1.2.11+8" [[Zygote]] -deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "f8329b595c465caf3ca87c4f744e6041a4983e43" +deps = ["ArrayLayouts", "DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"] +git-tree-sha1 = "7dc5fdb4917ac5a84e199ae654316a01cd4a278b" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.4.8" +version = "0.4.9" [[ZygoteRules]] deps = ["MacroTools"] diff --git a/docs/src/models/nnlib.md b/docs/src/models/nnlib.md index 698a95ae..6dbfd4f4 100644 --- a/docs/src/models/nnlib.md +++ b/docs/src/models/nnlib.md @@ -12,9 +12,9 @@ NNlib.gelu NNlib.leakyrelu NNlib.logcosh NNlib.logsigmoid -NNlib.sigmoid NNlib.relu NNlib.selu +NNlib.sigmoid NNlib.softplus NNlib.softsign NNlib.swish @@ -47,4 +47,5 @@ NNlib.depthwiseconv NNlib.batched_mul NNlib.batched_mul! NNlib.batched_adjoint +NNlib.batched_transpose ``` \ No newline at end of file diff --git a/docs/src/performance.md b/docs/src/performance.md index 06a4f690..0af8ef3b 100644 --- a/docs/src/performance.md +++ b/docs/src/performance.md @@ -4,7 +4,7 @@ All the usual [Julia performance tips apply](https://docs.julialang.org/en/v1/ma As always [profiling your code](https://docs.julialang.org/en/v1/manual/profile/#Profiling-1) is generally a useful way of finding bottlenecks. Below follow some Flux specific tips/reminders. -## Don't use more precision than you need. +## Don't use more precision than you need Flux works great with all kinds of number types. But often you do not need to be working with say `Float64` (let alone `BigFloat`). @@ -14,7 +14,8 @@ Which means allocations occur much faster. And you use less memory. -## Make sure your activation and loss functions preserve the type of their inputs +## Preserve inputs' types + Not only should your activation and loss functions be [type-stable](https://docs.julialang.org/en/v1/manual/performance-tips/#Write-%22type-stable%22-functions-1), they should also preserve the type of their inputs. @@ -29,24 +30,22 @@ because it results in having to use slow mixed type multiplication in the dense Similar situations can occur in the loss function during backpropagation. Which means if you change your data say from `Float64` to `Float32` (which should give a speedup: see above), -you will see a large slow-down +you will see a large slow-down. This can occur sneakily, because you can cause type-promotion by interacting with a numeric literals. E.g. the following will have run into the same problem as above: ``` - leaky_tanh(x) = 0.01x + tanh(x) + leaky_tanh(x) = 0.01*x + tanh(x) ``` -While one could change your activation function (e.g. to use `0.01f0x`) to avoid this when ever your inputs change, -the idiomatic (and safe way) is to use `oftype`. - +While one could change the activation function (e.g. to use `0.01f0x`), the idiomatic (and safe way) to avoid type casts whenever inputs changes is to use `oftype`: ``` - leaky_tanh(x) = oftype(x/1, 0.01)x + tanh(x) + leaky_tanh(x) = oftype(x/1, 0.01)*x + tanh(x) ``` -## Evaluate batches as Matrices of features, rather than sequences of Vector features +## Evaluate batches as Matrices of features While it can sometimes be tempting to process your observations (feature vectors) one at a time e.g. diff --git a/src/data/dataloader.jl b/src/data/dataloader.jl index 8868a9b0..9da14650 100644 --- a/src/data/dataloader.jl +++ b/src/data/dataloader.jl @@ -23,21 +23,25 @@ dimension. If `shuffle=true`, shuffles the observations each time iterations are re-started. If `partial=false`, drops the last mini-batch if it is smaller than the batchsize. +The original data is preserved as a tuple in the `data` field of the DataLoader. + Example usage: Xtrain = rand(10, 100) - dtrain = DataLoader(Xtrain, batchsize=2) - # iterate over 50 mini-batches - for x in dtrain: + train_loader = DataLoader(Xtrain, batchsize=2) + # iterate over 50 mini-batches of size 2 + for x in train_loader: @assert size(x) == (10, 2) ... end + train_loader.data # original dataset + Xtrain = rand(10, 100) Ytrain = rand(100) - dtrain = DataLoader(Xtrain, Ytrain, batchsize=2, shuffle=true) + train_loader = DataLoader(Xtrain, Ytrain, batchsize=2, shuffle=true) for epoch in 1:100 - for (x, y) in dtrain: + for (x, y) in train_loader: @assert size(x) == (10, 2) @assert size(y) == (2,) ... @@ -46,7 +50,7 @@ Example usage: # train for 10 epochs using IterTools: ncycle - Flux.train!(loss, ps, ncycle(dtrain, 10), opt) + Flux.train!(loss, ps, ncycle(train_loader, 10), opt) """ function DataLoader(data...; batchsize=1, shuffle=false, partial=true) length(data) > 0 || throw(ArgumentError("Need at least one data input"))