Merge pull request #986 from FluxML/restructure

Destructure/restructure for models
This commit is contained in:
Dhairya Gandhi 2020-01-13 13:04:48 +05:30 committed by GitHub
commit 370fd978fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 115 additions and 93 deletions

View File

@ -1,8 +1,10 @@
# This file is machine-generated - editing it directly is not advised
[[AbstractFFTs]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "380e36c66edfa099cd90116b24c1ce8cafccac40"
git-tree-sha1 = "051c95d6836228d120f5f4b984dd5aba1624f716"
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
version = "0.4.1"
version = "0.5.0"
[[AbstractTrees]]
deps = ["Markdown", "Test"]
@ -19,12 +21,6 @@ version = "1.0.0"
[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
[[BinDeps]]
deps = ["Compat", "Libdl", "SHA", "URIParser"]
git-tree-sha1 = "12093ca6cdd0ee547c39b1870e0c9c3f154d9ca9"
uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
version = "0.8.10"
[[BinaryProvider]]
deps = ["Libdl", "SHA"]
git-tree-sha1 = "5b08ed6036d9d3f0ee6369410b830f8873d4024c"
@ -38,21 +34,21 @@ version = "0.2.0"
[[CUDAapi]]
deps = ["Libdl", "Logging"]
git-tree-sha1 = "6eee47385c81ed3b3f716b745697869c712c2df3"
git-tree-sha1 = "56a813440ac98a1aa64672ab460a1512552211a7"
uuid = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
version = "2.0.0"
version = "2.1.0"
[[CUDAdrv]]
deps = ["CEnum", "CUDAapi", "Printf"]
git-tree-sha1 = "0f39fddace3324707469ace7fbcbc7b28d5cf921"
git-tree-sha1 = "1fce616fa0806c67c133eb1d2f68f0f1a7504665"
uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
version = "4.0.4"
version = "5.0.1"
[[CUDAnative]]
deps = ["Adapt", "CEnum", "CUDAapi", "CUDAdrv", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "Printf", "TimerOutputs"]
git-tree-sha1 = "93f6c917ab2a9b5bb54f8f738f4ec1a6693cb716"
git-tree-sha1 = "6e11d5c2c91fc623952e94c4fb73f9c4db74795a"
uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
version = "2.5.5"
version = "2.7.0"
[[CodecZlib]]
deps = ["BinaryProvider", "Libdl", "TranscodingStreams"]
@ -62,9 +58,9 @@ version = "0.6.0"
[[ColorTypes]]
deps = ["FixedPointNumbers", "Random"]
git-tree-sha1 = "10050a24b09e8e41b951e9976b109871ce98d965"
git-tree-sha1 = "7b62b728a5f3dd6ee3b23910303ccf27e82fad5e"
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
version = "0.8.0"
version = "0.8.1"
[[Colors]]
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport"]
@ -78,25 +74,13 @@ git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0"
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
version = "0.2.0"
[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "ed2c4abadf84c53d9e58510b5fc48912c2336fbb"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "2.2.0"
[[Conda]]
deps = ["JSON", "VersionParsing"]
git-tree-sha1 = "9a11d428dcdc425072af4aea19ab1e8c3e01c032"
uuid = "8f4d0f93-b110-5947-807f-2305c1781a2d"
version = "1.3.0"
[[CuArrays]]
deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "DataStructures", "GPUArrays", "Libdl", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"]
git-tree-sha1 = "7e00178b18672ee2cf37244ac2a273b6b0701b04"
git-tree-sha1 = "5203ed37039c74c5eab31e9fcdc40f23c7e943a3"
repo-rev = "master"
repo-url = "https://github.com/JuliaGPU/CuArrays.jl.git"
uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
version = "1.4.7"
version = "1.6.0"
[[DataAPI]]
git-tree-sha1 = "674b67f344687a88310213ddfa8a2b3c76cc4252"
@ -105,9 +89,9 @@ version = "1.1.0"
[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "a1b652fb77ae8ca7ea328fa7ba5aa151036e5c10"
git-tree-sha1 = "f784254f428fb8fd7ac15982e5862a38a44523d3"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.17.6"
version = "0.17.7"
[[Dates]]
deps = ["Printf"]
@ -118,26 +102,32 @@ deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
[[DiffResults]]
deps = ["Compat", "StaticArrays"]
git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c"
deps = ["StaticArrays"]
git-tree-sha1 = "da24935df8e0c6cf28de340b958f6aac88eaa0cc"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "0.0.4"
version = "1.0.2"
[[DiffRules]]
deps = ["NaNMath", "Random", "SpecialFunctions"]
git-tree-sha1 = "f734b5f6bc9c909027ef99f6d91d5d9e4b111eed"
git-tree-sha1 = "10dca52cf6d4a62d82528262921daf63b99704a2"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "0.1.0"
version = "1.0.0"
[[Distributed]]
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[FFTW]]
deps = ["AbstractFFTs", "BinaryProvider", "Conda", "Libdl", "LinearAlgebra", "Reexport", "Test"]
git-tree-sha1 = "6c5b420da0b8c12098048561b8d58f81adea506f"
deps = ["AbstractFFTs", "FFTW_jll", "IntelOpenMP_jll", "Libdl", "LinearAlgebra", "MKL_jll", "Reexport"]
git-tree-sha1 = "109d82fa4b00429f9afcce873e9f746f11f018d3"
uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
version = "1.0.1"
version = "1.2.0"
[[FFTW_jll]]
deps = ["Libdl", "Pkg"]
git-tree-sha1 = "05674f209a6e3387dd103a945b0113eeb64b1a58"
uuid = "f5851436-0d7a-5f13-b9de-f02708fd171a"
version = "3.3.9+3"
[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"]
@ -152,15 +142,15 @@ version = "0.6.1"
[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"]
git-tree-sha1 = "da46ac97b17793eba44ff366dc6cb70f1238a738"
git-tree-sha1 = "840700059391d36e2498d89c2e82c08f261f2a2a"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.7"
version = "0.10.8"
[[GPUArrays]]
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"]
git-tree-sha1 = "a0a3b927b1a06e63fb8b91950cc7df340b7d912c"
git-tree-sha1 = "e756da6cee76a5f1436a05827fa8fdf3badc577f"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "2.0.0"
version = "2.0.1"
[[IRTools]]
deps = ["InteractiveUtils", "MacroTools", "Test"]
@ -168,15 +158,15 @@ git-tree-sha1 = "72421971e60917b8cd7737f9577c4f0f87eab306"
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
version = "0.3.0"
[[InteractiveUtils]]
deps = ["LinearAlgebra", "Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[IntelOpenMP_jll]]
deps = ["Libdl", "Pkg"]
git-tree-sha1 = "fb8e1c7a5594ba56f9011310790e03b5384998d6"
uuid = "1d5cc7b8-4909-519e-a0f8-d0f5ad9712d0"
version = "2018.0.3+0"
[[JSON]]
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
git-tree-sha1 = "b34d7cef7b337321e97d22242c3c2b91f476748e"
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
version = "0.21.0"
[[InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[Juno]]
deps = ["Base64", "Logging", "Media", "Profile", "Test"]
@ -186,9 +176,9 @@ version = "0.7.2"
[[LLVM]]
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "74fe444b8b6d1ac01d639b2f9eaf395bcc2e24fc"
git-tree-sha1 = "1d08d7e4250f452f6cb20e4574daaebfdbee0ff7"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "1.3.2"
version = "1.3.3"
[[LibGit2]]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
@ -203,11 +193,17 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
[[MKL_jll]]
deps = ["Libdl", "Pkg"]
git-tree-sha1 = "61069ae718b8ab1e325bbfb4e5268902e7ea08e3"
uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
version = "2019.0.117+0"
[[MacroTools]]
deps = ["Compat", "DataStructures", "Test"]
git-tree-sha1 = "82921f0e3bde6aebb8e524efc20f4042373c0c06"
deps = ["DataStructures", "Markdown", "Random"]
git-tree-sha1 = "e2fc7a55bb2224e203bbd8b59f72b91323233458"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.5.2"
version = "0.5.3"
[[Markdown]]
deps = ["Base64"]
@ -229,30 +225,30 @@ version = "0.4.3"
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
[[NNlib]]
deps = ["Libdl", "LinearAlgebra", "Requires", "Statistics", "TimerOutputs"]
git-tree-sha1 = "0c667371391fc6bb31f7f12f96a56a17098b3de8"
deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Requires", "Statistics"]
git-tree-sha1 = "135c0de4794d5e214b06f1fb4787af4a72896e61"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.6.0"
version = "0.6.2"
[[NaNMath]]
git-tree-sha1 = "928b8ca9b2791081dc71a51c55347c27c618760f"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "0.3.3"
[[OpenSpecFun_jll]]
deps = ["Libdl", "Pkg"]
git-tree-sha1 = "65f672edebf3f4e613ddf37db9dcbd7a407e5e90"
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
version = "0.5.3+1"
[[OrderedCollections]]
deps = ["Random", "Serialization", "Test"]
git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.1.0"
[[Parsers]]
deps = ["Dates", "Test"]
git-tree-sha1 = "0139ba59ce9bc680e2925aec5b7db79065d60556"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "0.3.10"
[[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
[[Printf]]
@ -278,10 +274,10 @@ uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "0.2.0"
[[Requires]]
deps = ["Test"]
git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1"
deps = ["UUIDs"]
git-tree-sha1 = "999513b7dea8ac17359ed50ae8ea089e4464e35e"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "0.5.2"
version = "1.0.0"
[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
@ -289,10 +285,6 @@ uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
[[SharedArrays]]
deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
[[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
@ -307,10 +299,10 @@ deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
[[SpecialFunctions]]
deps = ["BinDeps", "BinaryProvider", "Libdl"]
git-tree-sha1 = "3bdd374b6fd78faf0119b8c5d538788dbf910c6e"
deps = ["OpenSpecFun_jll"]
git-tree-sha1 = "268052ee908b2c086cc0011f528694f02f3e2408"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "0.8.0"
version = "0.9.0"
[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
@ -344,25 +336,13 @@ git-tree-sha1 = "7c53c35547de1c5b9d46a4797cf6d8253807108c"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.9.5"
[[URIParser]]
deps = ["Test", "Unicode"]
git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69"
uuid = "30578b45-9adc-5946-b283-645ec420af67"
version = "0.4.0"
[[UUIDs]]
deps = ["Random"]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
[[VersionParsing]]
deps = ["Compat"]
git-tree-sha1 = "c9d5aa108588b978bd859554660c8a5c4f2f7669"
uuid = "81def892-9a0e-5fdd-b105-ffc91e053289"
version = "1.1.3"
[[ZipFile]]
deps = ["BinaryProvider", "Libdl", "Printf"]
git-tree-sha1 = "580ce62b6c14244916cc28ad54f8a2e2886f843d"
@ -371,9 +351,9 @@ version = "0.8.3"
[[Zygote]]
deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
git-tree-sha1 = "e4245b9c5362346e154b62842a89a18e0210b92b"
git-tree-sha1 = "7e293d7bef87c2cf2847e99ed0da4edadb75fe90"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.4.1"
version = "0.4.4"
[[ZygoteRules]]
deps = ["MacroTools"]

View File

@ -103,6 +103,48 @@ function batchseq(xs, pad = nothing, n = maximum(length(x) for x in xs))
[batch([xs_[j][i] for j = 1:length(xs_)]) for i = 1:n]
end
# Flattening models to weight vectors, and back
function _restructure(m, xs)
i = 0
fmap(m) do x
x isa AbstractArray || return x
x = reshape(xs[i.+(1:length(x))], size(x))
i += length(x)
return x
end
end
"""
destructure(m)
Flatten a model's parameters into a single weight vector.
julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
julia> θ, re = destructure(m);
julia> θ
67-element Array{Float32,1}:
-0.1407104
...
The second return value `re` allows you to reconstruct the original network after making
modifications to the weight vector (for example, with a hypernetwork).
julia> re(θ .* 2)
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
"""
function destructure(m)
xs = Zygote.Buffer([])
fmap(m) do x
x isa AbstractArray && push!(xs, x)
return x
end
return vcat(vec.(copy(xs))...), p -> _restructure(m, p)
end
# Other
"""