Compare commits
6 Commits
Author | SHA1 | Date |
---|---|---|
CarloLucibello | f8c8bb4e35 | |
CarloLucibello | c1f0c29026 | |
CarloLucibello | 14e7181c7c | |
CarloLucibello | 89191bdeb1 | |
CarloLucibello | c6ba49e8ea | |
CarloLucibello | d77dbc4931 |
|
@ -4,3 +4,4 @@
|
|||
docs/build/
|
||||
docs/site/
|
||||
deps
|
||||
Manifest.toml
|
||||
|
|
367
Manifest.toml
367
Manifest.toml
|
@ -1,367 +0,0 @@
|
|||
# This file is machine-generated - editing it directly is not advised
|
||||
|
||||
[[AbstractFFTs]]
|
||||
deps = ["LinearAlgebra"]
|
||||
git-tree-sha1 = "051c95d6836228d120f5f4b984dd5aba1624f716"
|
||||
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
|
||||
version = "0.5.0"
|
||||
|
||||
[[AbstractTrees]]
|
||||
deps = ["Markdown"]
|
||||
git-tree-sha1 = "86d092c2599f1f7bb01668bf8eb3412f98d61e47"
|
||||
uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
|
||||
version = "0.3.2"
|
||||
|
||||
[[Adapt]]
|
||||
deps = ["LinearAlgebra"]
|
||||
git-tree-sha1 = "c88cfc7f9c1f9f8633cddf0b56e86302b70f64c5"
|
||||
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
|
||||
version = "1.0.1"
|
||||
|
||||
[[ArrayLayouts]]
|
||||
deps = ["FillArrays", "LinearAlgebra"]
|
||||
git-tree-sha1 = "41956a49a8a4fefa1bf6664bca4a3035aba4c3a0"
|
||||
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
|
||||
version = "0.2.3"
|
||||
|
||||
[[Base64]]
|
||||
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
|
||||
|
||||
[[BinaryProvider]]
|
||||
deps = ["Libdl", "SHA"]
|
||||
git-tree-sha1 = "5b08ed6036d9d3f0ee6369410b830f8873d4024c"
|
||||
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
|
||||
version = "0.5.8"
|
||||
|
||||
[[CEnum]]
|
||||
git-tree-sha1 = "62847acab40e6855a9b5905ccb99c2b5cf6b3ebb"
|
||||
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
|
||||
version = "0.2.0"
|
||||
|
||||
[[CUDAapi]]
|
||||
deps = ["Libdl", "Logging"]
|
||||
git-tree-sha1 = "831b825d10104bd29e28f6da93312a976830717b"
|
||||
uuid = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
|
||||
version = "4.0.0"
|
||||
|
||||
[[CUDAdrv]]
|
||||
deps = ["CEnum", "CUDAapi", "Printf"]
|
||||
git-tree-sha1 = "e650cbaee92b60433313157926b1e80d0c3a0e2e"
|
||||
uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
|
||||
version = "6.2.2"
|
||||
|
||||
[[CUDAnative]]
|
||||
deps = ["Adapt", "BinaryProvider", "CEnum", "CUDAapi", "CUDAdrv", "Cthulhu", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "MacroTools", "Pkg", "Printf", "TimerOutputs"]
|
||||
git-tree-sha1 = "d1fc99635d0002c8a819b78cb1f441eb44310725"
|
||||
uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
|
||||
version = "3.0.2"
|
||||
|
||||
[[CodeTracking]]
|
||||
deps = ["InteractiveUtils", "UUIDs"]
|
||||
git-tree-sha1 = "0becdab7e6fbbcb7b88d8de5b72e5bb2f28239f3"
|
||||
uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
|
||||
version = "0.5.8"
|
||||
|
||||
[[CodecZlib]]
|
||||
deps = ["TranscodingStreams", "Zlib_jll"]
|
||||
git-tree-sha1 = "ded953804d019afa9a3f98981d99b33e3db7b6da"
|
||||
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
|
||||
version = "0.7.0"
|
||||
|
||||
[[ColorTypes]]
|
||||
deps = ["FixedPointNumbers", "Random"]
|
||||
git-tree-sha1 = "c4c1cca28748906265ed62c788d6fe6f0134d264"
|
||||
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
|
||||
version = "0.10.0"
|
||||
|
||||
[[Colors]]
|
||||
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Reexport"]
|
||||
git-tree-sha1 = "2fdeb981ebcf52cd800ddb6a0aa5eac34153552d"
|
||||
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
|
||||
version = "0.12.0"
|
||||
|
||||
[[CommonSubexpressions]]
|
||||
deps = ["Test"]
|
||||
git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0"
|
||||
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
|
||||
version = "0.2.0"
|
||||
|
||||
[[CompilerSupportLibraries_jll]]
|
||||
deps = ["Libdl", "Pkg"]
|
||||
git-tree-sha1 = "7c4f882c41faa72118841185afc58a2eb00ef612"
|
||||
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
|
||||
version = "0.3.3+0"
|
||||
|
||||
[[Cthulhu]]
|
||||
deps = ["CodeTracking", "InteractiveUtils", "REPL", "Unicode"]
|
||||
git-tree-sha1 = "484790098c85c26f8e59051f8ff1a0745c034a7d"
|
||||
uuid = "f68482b8-f384-11e8-15f7-abe071a5a75f"
|
||||
version = "1.0.1"
|
||||
|
||||
[[CuArrays]]
|
||||
deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "DataStructures", "GPUArrays", "Libdl", "LinearAlgebra", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "Requires", "SparseArrays", "Statistics", "TimerOutputs"]
|
||||
git-tree-sha1 = "e8c55b38dcca955f5aed8ec4479cdc95810db1e1"
|
||||
uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
|
||||
version = "2.0.1"
|
||||
|
||||
[[DataAPI]]
|
||||
git-tree-sha1 = "674b67f344687a88310213ddfa8a2b3c76cc4252"
|
||||
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
|
||||
version = "1.1.0"
|
||||
|
||||
[[DataStructures]]
|
||||
deps = ["InteractiveUtils", "OrderedCollections"]
|
||||
git-tree-sha1 = "73eb18320fe3ba58790c8b8f6f89420f0a622773"
|
||||
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
|
||||
version = "0.17.11"
|
||||
|
||||
[[Dates]]
|
||||
deps = ["Printf"]
|
||||
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
|
||||
|
||||
[[DelimitedFiles]]
|
||||
deps = ["Mmap"]
|
||||
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
|
||||
|
||||
[[DiffResults]]
|
||||
deps = ["StaticArrays"]
|
||||
git-tree-sha1 = "da24935df8e0c6cf28de340b958f6aac88eaa0cc"
|
||||
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
|
||||
version = "1.0.2"
|
||||
|
||||
[[DiffRules]]
|
||||
deps = ["NaNMath", "Random", "SpecialFunctions"]
|
||||
git-tree-sha1 = "eb0c34204c8410888844ada5359ac8b96292cfd1"
|
||||
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
|
||||
version = "1.0.1"
|
||||
|
||||
[[Distributed]]
|
||||
deps = ["Random", "Serialization", "Sockets"]
|
||||
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
|
||||
|
||||
[[FillArrays]]
|
||||
deps = ["LinearAlgebra", "Random", "SparseArrays"]
|
||||
git-tree-sha1 = "51cc2f9bc4eb9c6c0e81ec2f779d1085583cc956"
|
||||
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
|
||||
version = "0.8.7"
|
||||
|
||||
[[FixedPointNumbers]]
|
||||
git-tree-sha1 = "3ba9ea634d4c8b289d590403b4a06f8e227a6238"
|
||||
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
|
||||
version = "0.8.0"
|
||||
|
||||
[[ForwardDiff]]
|
||||
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"]
|
||||
git-tree-sha1 = "869540e4367122fbffaace383a5bdc34d6e5e5ac"
|
||||
uuid = "f6369f11-7733-5829-9624-2563aa707210"
|
||||
version = "0.10.10"
|
||||
|
||||
[[GPUArrays]]
|
||||
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"]
|
||||
git-tree-sha1 = "d586762b08dcda13228df8967119b9cb6f22ade5"
|
||||
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
|
||||
version = "3.1.0"
|
||||
|
||||
[[IRTools]]
|
||||
deps = ["InteractiveUtils", "MacroTools", "Test"]
|
||||
git-tree-sha1 = "1a4355e4b5b50be2311ebb644f34f3306dbd0410"
|
||||
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
|
||||
version = "0.3.1"
|
||||
|
||||
[[InteractiveUtils]]
|
||||
deps = ["Markdown"]
|
||||
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
|
||||
|
||||
[[Juno]]
|
||||
deps = ["Base64", "Logging", "Media", "Profile"]
|
||||
git-tree-sha1 = "e1ba2a612645b3e07c773c3a208f215745081fe6"
|
||||
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
|
||||
version = "0.8.1"
|
||||
|
||||
[[LLVM]]
|
||||
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
|
||||
git-tree-sha1 = "b6b86801ae2f2682e0a4889315dc76b68db2de71"
|
||||
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
|
||||
version = "1.3.4"
|
||||
|
||||
[[LibGit2]]
|
||||
deps = ["Printf"]
|
||||
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
|
||||
|
||||
[[Libdl]]
|
||||
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
|
||||
|
||||
[[LinearAlgebra]]
|
||||
deps = ["Libdl"]
|
||||
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
|
||||
|
||||
[[Logging]]
|
||||
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
|
||||
|
||||
[[MacroTools]]
|
||||
deps = ["Markdown", "Random"]
|
||||
git-tree-sha1 = "f7d2e3f654af75f01ec49be82c231c382214223a"
|
||||
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
|
||||
version = "0.5.5"
|
||||
|
||||
[[Markdown]]
|
||||
deps = ["Base64"]
|
||||
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
|
||||
|
||||
[[Media]]
|
||||
deps = ["MacroTools", "Test"]
|
||||
git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58"
|
||||
uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27"
|
||||
version = "0.5.0"
|
||||
|
||||
[[Missings]]
|
||||
deps = ["DataAPI"]
|
||||
git-tree-sha1 = "de0a5ce9e5289f27df672ffabef4d1e5861247d5"
|
||||
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
|
||||
version = "0.4.3"
|
||||
|
||||
[[Mmap]]
|
||||
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
|
||||
|
||||
[[NNlib]]
|
||||
deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Requires", "Statistics"]
|
||||
git-tree-sha1 = "d9f196d911f55aeaff11b11f681b135980783824"
|
||||
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
|
||||
version = "0.6.6"
|
||||
|
||||
[[NaNMath]]
|
||||
git-tree-sha1 = "928b8ca9b2791081dc71a51c55347c27c618760f"
|
||||
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
|
||||
version = "0.3.3"
|
||||
|
||||
[[OpenSpecFun_jll]]
|
||||
deps = ["CompilerSupportLibraries_jll", "Libdl", "Pkg"]
|
||||
git-tree-sha1 = "d51c416559217d974a1113522d5919235ae67a87"
|
||||
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
|
||||
version = "0.5.3+3"
|
||||
|
||||
[[OrderedCollections]]
|
||||
deps = ["Random", "Serialization", "Test"]
|
||||
git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
|
||||
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
|
||||
version = "1.1.0"
|
||||
|
||||
[[Pkg]]
|
||||
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
|
||||
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
|
||||
|
||||
[[Printf]]
|
||||
deps = ["Unicode"]
|
||||
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
|
||||
|
||||
[[Profile]]
|
||||
deps = ["Printf"]
|
||||
uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
|
||||
|
||||
[[REPL]]
|
||||
deps = ["InteractiveUtils", "Markdown", "Sockets"]
|
||||
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
|
||||
|
||||
[[Random]]
|
||||
deps = ["Serialization"]
|
||||
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
|
||||
|
||||
[[Reexport]]
|
||||
deps = ["Pkg"]
|
||||
git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0"
|
||||
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
|
||||
version = "0.2.0"
|
||||
|
||||
[[Requires]]
|
||||
deps = ["UUIDs"]
|
||||
git-tree-sha1 = "d37400976e98018ee840e0ca4f9d20baa231dc6b"
|
||||
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
|
||||
version = "1.0.1"
|
||||
|
||||
[[SHA]]
|
||||
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
|
||||
|
||||
[[Serialization]]
|
||||
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
|
||||
|
||||
[[Sockets]]
|
||||
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
|
||||
|
||||
[[SortingAlgorithms]]
|
||||
deps = ["DataStructures", "Random", "Test"]
|
||||
git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd"
|
||||
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
|
||||
version = "0.3.1"
|
||||
|
||||
[[SparseArrays]]
|
||||
deps = ["LinearAlgebra", "Random"]
|
||||
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
|
||||
|
||||
[[SpecialFunctions]]
|
||||
deps = ["OpenSpecFun_jll"]
|
||||
git-tree-sha1 = "e19b98acb182567bcb7b75bb5d9eedf3a3b5ec6c"
|
||||
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
|
||||
version = "0.10.0"
|
||||
|
||||
[[StaticArrays]]
|
||||
deps = ["LinearAlgebra", "Random", "Statistics"]
|
||||
git-tree-sha1 = "5a3bcb6233adabde68ebc97be66e95dcb787424c"
|
||||
uuid = "90137ffa-7385-5640-81b9-e52037218182"
|
||||
version = "0.12.1"
|
||||
|
||||
[[Statistics]]
|
||||
deps = ["LinearAlgebra", "SparseArrays"]
|
||||
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
|
||||
|
||||
[[StatsBase]]
|
||||
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
|
||||
git-tree-sha1 = "a6102b1f364befdb05746f386b67c6b7e3262c45"
|
||||
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
|
||||
version = "0.33.0"
|
||||
|
||||
[[Test]]
|
||||
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
|
||||
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||
|
||||
[[TimerOutputs]]
|
||||
deps = ["Printf"]
|
||||
git-tree-sha1 = "311765af81bbb48d7bad01fb016d9c328c6ede03"
|
||||
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
|
||||
version = "0.5.3"
|
||||
|
||||
[[TranscodingStreams]]
|
||||
deps = ["Random", "Test"]
|
||||
git-tree-sha1 = "7c53c35547de1c5b9d46a4797cf6d8253807108c"
|
||||
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
|
||||
version = "0.9.5"
|
||||
|
||||
[[UUIDs]]
|
||||
deps = ["Random", "SHA"]
|
||||
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
|
||||
|
||||
[[Unicode]]
|
||||
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
|
||||
|
||||
[[ZipFile]]
|
||||
deps = ["Libdl", "Printf", "Zlib_jll"]
|
||||
git-tree-sha1 = "8748302cfdec02c4ae9c97b112cf10003f7f767f"
|
||||
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
|
||||
version = "0.9.1"
|
||||
|
||||
[[Zlib_jll]]
|
||||
deps = ["Libdl", "Pkg"]
|
||||
git-tree-sha1 = "2f6c3e15e20e036ee0a0965879b31442b7ec50fa"
|
||||
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
|
||||
version = "1.2.11+9"
|
||||
|
||||
[[Zygote]]
|
||||
deps = ["AbstractFFTs", "ArrayLayouts", "DiffRules", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
|
||||
git-tree-sha1 = "1ccbfbe8930376e31752b812daa2532c723dc332"
|
||||
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
|
||||
version = "0.4.13"
|
||||
|
||||
[[ZygoteRules]]
|
||||
deps = ["MacroTools"]
|
||||
git-tree-sha1 = "b3b4882cc9accf6731a08cc39543fbc6b669dca8"
|
||||
uuid = "700de1a5-db45-46bc-99cf-38207098b444"
|
||||
version = "0.2.0"
|
|
@ -1,6 +1,6 @@
|
|||
name = "Flux"
|
||||
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
|
||||
version = "0.10.4"
|
||||
version = "0.11.0"
|
||||
|
||||
[deps]
|
||||
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
|
||||
|
|
|
@ -51,4 +51,6 @@ export Iris
|
|||
include("housing.jl")
|
||||
export Housing
|
||||
|
||||
@deprecate DataLoader(x...; kws...) DataLoader(x; kws...)
|
||||
|
||||
end
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Adapted from Knet's src/data.jl (author: Deniz Yuret)
|
||||
|
||||
struct DataLoader
|
||||
data
|
||||
struct DataLoader{D}
|
||||
data::D
|
||||
batchsize::Int
|
||||
nobs::Int
|
||||
partial::Bool
|
||||
|
@ -11,21 +11,20 @@ struct DataLoader
|
|||
end
|
||||
|
||||
"""
|
||||
DataLoader(data...; batchsize=1, shuffle=false, partial=true)
|
||||
DataLoader(data; batchsize=1, shuffle=false, partial=true)
|
||||
|
||||
An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
|
||||
(except possibly the last one).
|
||||
|
||||
Takes as input one or more data tensors, e.g. X in unsupervised learning, X and Y in
|
||||
supervised learning. The last dimension in each tensor is considered to be the observation
|
||||
dimension.
|
||||
Takes as input a data tensors or a tuple of one or more such tensors.
|
||||
The last dimension in each tensor is considered to be the observation dimension.
|
||||
|
||||
If `shuffle=true`, shuffles the observations each time iterations are re-started.
|
||||
If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
|
||||
|
||||
The original data is preserved as a tuple in the `data` field of the DataLoader.
|
||||
The original data is preserved in the `data` field of the DataLoader.
|
||||
|
||||
Example usage:
|
||||
Usage example:
|
||||
|
||||
Xtrain = rand(10, 100)
|
||||
train_loader = DataLoader(Xtrain, batchsize=2)
|
||||
|
@ -37,9 +36,16 @@ Example usage:
|
|||
|
||||
train_loader.data # original dataset
|
||||
|
||||
# similar, but yielding tuples
|
||||
train_loader = DataLoader((Xtrain,), batchsize=2)
|
||||
for (x,) in train_loader
|
||||
@assert size(x) == (10, 2)
|
||||
...
|
||||
end
|
||||
|
||||
Xtrain = rand(10, 100)
|
||||
Ytrain = rand(100)
|
||||
train_loader = DataLoader(Xtrain, Ytrain, batchsize=2, shuffle=true)
|
||||
train_loader = DataLoader((Xtrain, Ytrain), batchsize=2, shuffle=true)
|
||||
for epoch in 1:100
|
||||
for (x, y) in train_loader
|
||||
@assert size(x) == (10, 2)
|
||||
|
@ -52,25 +58,19 @@ Example usage:
|
|||
using IterTools: ncycle
|
||||
Flux.train!(loss, ps, ncycle(train_loader, 10), opt)
|
||||
"""
|
||||
function DataLoader(data...; batchsize=1, shuffle=false, partial=true)
|
||||
length(data) > 0 || throw(ArgumentError("Need at least one data input"))
|
||||
function DataLoader(data; batchsize=1, shuffle=false, partial=true)
|
||||
batchsize > 0 || throw(ArgumentError("Need positive batchsize"))
|
||||
|
||||
nx = size(data[1])[end]
|
||||
for i=2:length(data)
|
||||
nx != size(data[i])[end] && throw(DimensionMismatch("All data should contain same number of observations"))
|
||||
n = _nobs(data)
|
||||
if n < batchsize
|
||||
@warn "Number of observations less than batchsize, decreasing the batchsize to $n"
|
||||
batchsize = n
|
||||
end
|
||||
if nx < batchsize
|
||||
@warn "Number of data points less than batchsize, decreasing the batchsize to $nx"
|
||||
batchsize = nx
|
||||
end
|
||||
imax = partial ? nx : nx - batchsize + 1
|
||||
ids = 1:min(nx, batchsize)
|
||||
DataLoader(data, batchsize, nx, partial, imax, [1:nx;], shuffle)
|
||||
imax = partial ? n : n - batchsize + 1
|
||||
ids = 1:min(n, batchsize)
|
||||
DataLoader(data, batchsize, n, partial, imax, [1:n;], shuffle)
|
||||
end
|
||||
|
||||
getdata(x::AbstractArray, ids) = x[(Base.Colon() for _=1:ndims(x)-1)..., ids]
|
||||
|
||||
@propagate_inbounds function Base.iterate(d::DataLoader, i=0) # returns data in d.indices[i+1:i+batchsize]
|
||||
i >= d.imax && return nothing
|
||||
if d.shuffle && i == 0
|
||||
|
@ -78,11 +78,7 @@ getdata(x::AbstractArray, ids) = x[(Base.Colon() for _=1:ndims(x)-1)..., ids]
|
|||
end
|
||||
nexti = min(i + d.batchsize, d.nobs)
|
||||
ids = d.indices[i+1:nexti]
|
||||
if length(d.data) == 1
|
||||
batch = getdata(d.data[1], ids)
|
||||
else
|
||||
batch = ((getdata(x, ids) for x in d.data)...,)
|
||||
end
|
||||
batch = _getobs(d.data, ids)
|
||||
return (batch, nexti)
|
||||
end
|
||||
|
||||
|
@ -90,3 +86,22 @@ function Base.length(d::DataLoader)
|
|||
n = d.nobs / d.batchsize
|
||||
d.partial ? ceil(Int,n) : floor(Int,n)
|
||||
end
|
||||
|
||||
_nobs(data::AbstractArray) = size(data)[end]
|
||||
|
||||
function _nobs(data::Tuple)
|
||||
length(data) > 0 || throw(ArgumentError("Need at least one data input"))
|
||||
n = _nobs(data[1])
|
||||
if !all(x -> _nobs(x) == n, data[2:end])
|
||||
throw(DimensionMismatch("All data should contain same number of observations"))
|
||||
end
|
||||
return n
|
||||
end
|
||||
|
||||
function _getobs(data::A, i) where A<:AbstractArray{T,N} where {T,N}
|
||||
getindex(data, ntuple(i->Colon(), N-1)..., i)
|
||||
end
|
||||
|
||||
_getobs(data::Tuple, i) = ((_getobs(x, i) for x in data)...,)
|
||||
|
||||
Base.eltype(d::DataLoader{D}) where D = D
|
|
@ -56,14 +56,17 @@ function stop()
|
|||
throw(StopException())
|
||||
end
|
||||
|
||||
maketuple(x) = (x,)
|
||||
maketuple(x::Tuple) = x
|
||||
|
||||
"""
|
||||
train!(loss, params, data, opt; cb)
|
||||
|
||||
For each datapoint `d` in `data` compute the gradient of `loss(d...)` through
|
||||
backpropagation and call the optimizer `opt`.
|
||||
For each datapoint `d` in `data`, assumed to be a tuple, compute the gradient of `loss(d...)`
|
||||
with respect to `params`, and call the optimizer `opt`.
|
||||
|
||||
In case datapoints `d` are of numeric array type, assume no splatting is needed
|
||||
and compute the gradient of `loss(d)`.
|
||||
If `data` yields a tuple mini-batch `d` under iteration, it will be splatted in the function call
|
||||
`loss(d...)`, otherwise `loss(d)` will be called for non-tuple mini-batches.
|
||||
|
||||
A callback is given with the keyword argument `cb`. For example, this will print
|
||||
"training" every 10 seconds (using [`Flux.throttle`](@ref)):
|
||||
|
@ -80,14 +83,8 @@ function train!(loss, ps, data, opt; cb = () -> ())
|
|||
cb = runall(cb)
|
||||
@progress for d in data
|
||||
try
|
||||
if d isa AbstractArray{<:Number}
|
||||
gs = gradient(ps) do
|
||||
loss(d)
|
||||
end
|
||||
else
|
||||
gs = gradient(ps) do
|
||||
loss(d...)
|
||||
end
|
||||
gs = gradient(ps) do
|
||||
loss(maketuple(d)...)
|
||||
end
|
||||
update!(opt, ps, gs)
|
||||
cb()
|
||||
|
|
14
test/data.jl
14
test/data.jl
|
@ -4,6 +4,7 @@
|
|||
|
||||
d = DataLoader(X, batchsize=2)
|
||||
batches = collect(d)
|
||||
@test eltype(batches) == eltype(d) == typeof(X)
|
||||
@test length(batches) == 3
|
||||
@test batches[1] == X[:,1:2]
|
||||
@test batches[2] == X[:,3:4]
|
||||
|
@ -11,12 +12,21 @@
|
|||
|
||||
d = DataLoader(X, batchsize=2, partial=false)
|
||||
batches = collect(d)
|
||||
@test eltype(batches) == eltype(d) == typeof(X)
|
||||
@test length(batches) == 2
|
||||
@test batches[1] == X[:,1:2]
|
||||
@test batches[2] == X[:,3:4]
|
||||
|
||||
d = DataLoader(X, Y, batchsize=2)
|
||||
d = DataLoader((X,), batchsize=2, partial=false)
|
||||
batches = collect(d)
|
||||
@test eltype(batches) == eltype(d) == Tuple{typeof(X)}
|
||||
@test length(batches) == 2
|
||||
@test batches[1] == (X[:,1:2],)
|
||||
@test batches[2] == (X[:,3:4],)
|
||||
|
||||
d = DataLoader((X, Y), batchsize=2)
|
||||
batches = collect(d)
|
||||
@test eltype(batches) == eltype(d) == Tuple{typeof(X), typeof(Y)}
|
||||
@test length(batches) == 3
|
||||
@test length(batches[1]) == 2
|
||||
@test length(batches[2]) == 2
|
||||
|
@ -41,7 +51,7 @@
|
|||
X = ones(2, 10)
|
||||
Y = fill(2, 10)
|
||||
loss(x, y) = sum((y - x'*θ).^2)
|
||||
d = DataLoader(X, Y)
|
||||
d = DataLoader((X, Y))
|
||||
Flux.train!(loss, [θ], ncycle(d, 10), Descent(0.1))
|
||||
@test norm(θ .- 1) < 1e-10
|
||||
end
|
||||
|
|
|
@ -2,49 +2,45 @@ using Flux
|
|||
using Flux.Data
|
||||
using Test
|
||||
using Random, Statistics, LinearAlgebra
|
||||
using Documenter
|
||||
using IterTools: ncycle
|
||||
|
||||
Random.seed!(0)
|
||||
|
||||
@testset "Flux" begin
|
||||
@testset "Utils" begin
|
||||
include("utils.jl")
|
||||
end
|
||||
|
||||
@testset "Utils" begin
|
||||
include("utils.jl")
|
||||
end
|
||||
|
||||
@testset "Onehot" begin
|
||||
include("onehot.jl")
|
||||
end
|
||||
|
||||
@testset "Optimise" begin
|
||||
include("optimise.jl")
|
||||
end
|
||||
|
||||
@testset "Data" begin
|
||||
include("data.jl")
|
||||
end
|
||||
|
||||
@testset "Layers" begin
|
||||
include("layers/basic.jl")
|
||||
include("layers/normalisation.jl")
|
||||
include("layers/stateless.jl")
|
||||
include("layers/conv.jl")
|
||||
end
|
||||
|
||||
@testset "CUDA" begin
|
||||
if Flux.use_cuda[]
|
||||
include("cuda/cuda.jl")
|
||||
else
|
||||
@warn "CUDA unavailable, not testing GPU support"
|
||||
end
|
||||
@testset "Onehot" begin
|
||||
include("onehot.jl")
|
||||
end
|
||||
|
||||
@testset "Optimise" begin
|
||||
include("optimise.jl")
|
||||
end
|
||||
|
||||
@testset "Data" begin
|
||||
include("data.jl")
|
||||
end
|
||||
|
||||
@testset "Layers" begin
|
||||
include("layers/basic.jl")
|
||||
include("layers/normalisation.jl")
|
||||
include("layers/stateless.jl")
|
||||
include("layers/conv.jl")
|
||||
end
|
||||
|
||||
@testset "CUDA" begin
|
||||
if Flux.use_cuda[]
|
||||
include("cuda/cuda.jl")
|
||||
else
|
||||
@warn "CUDA unavailable, not testing GPU support"
|
||||
end
|
||||
end
|
||||
|
||||
@static if VERSION >= v"1.4"
|
||||
using Documenter
|
||||
@testset "Docs" begin
|
||||
if VERSION >= v"1.4"
|
||||
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
|
||||
doctest(Flux)
|
||||
end
|
||||
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
|
||||
doctest(Flux)
|
||||
end
|
||||
|
||||
end # testset Flux
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue