restructure; closes #747
This commit is contained in:
parent
e92da0cf85
commit
17732e7023
|
@ -371,9 +371,11 @@ version = "0.8.3"
|
||||||
|
|
||||||
[[Zygote]]
|
[[Zygote]]
|
||||||
deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
|
deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
|
||||||
git-tree-sha1 = "e4245b9c5362346e154b62842a89a18e0210b92b"
|
git-tree-sha1 = "04384d940b67d604dd393688fa60c1f0175e5faf"
|
||||||
|
repo-rev = "buffer-push"
|
||||||
|
repo-url = "https://github.com/FluxML/Zygote.jl.git"
|
||||||
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
|
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
|
||||||
version = "0.4.1"
|
version = "0.4.3"
|
||||||
|
|
||||||
[[ZygoteRules]]
|
[[ZygoteRules]]
|
||||||
deps = ["MacroTools"]
|
deps = ["MacroTools"]
|
||||||
|
|
42
src/utils.jl
42
src/utils.jl
|
@ -103,6 +103,48 @@ function batchseq(xs, pad = nothing, n = maximum(length(x) for x in xs))
|
||||||
[batch([xs_[j][i] for j = 1:length(xs_)]) for i = 1:n]
|
[batch([xs_[j][i] for j = 1:length(xs_)]) for i = 1:n]
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Flattening models to weight vectors, and back
|
||||||
|
|
||||||
|
function _restructure(m, xs)
|
||||||
|
i = 0
|
||||||
|
fmap(m) do x
|
||||||
|
x isa AbstractArray || return x
|
||||||
|
x = reshape(xs[i.+(1:length(x))], size(x))
|
||||||
|
i += length(x)
|
||||||
|
return x
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
destructure(m)
|
||||||
|
|
||||||
|
Flatten a model's parameters into a single weight vector.
|
||||||
|
|
||||||
|
julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
|
||||||
|
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
|
||||||
|
|
||||||
|
julia> θ, re = destructure(m);
|
||||||
|
|
||||||
|
julia> θ
|
||||||
|
67-element Array{Float32,1}:
|
||||||
|
-0.1407104
|
||||||
|
...
|
||||||
|
|
||||||
|
The second return value `re` allows you to reconstruct the original network after making
|
||||||
|
modifications to the weight vector (for example, with a hypernetwork).
|
||||||
|
|
||||||
|
julia> re(θ .* 2)
|
||||||
|
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
|
||||||
|
"""
|
||||||
|
function destructure(m)
|
||||||
|
xs = Zygote.Buffer([])
|
||||||
|
fmap(m) do x
|
||||||
|
x isa AbstractArray && push!(xs, x)
|
||||||
|
return x
|
||||||
|
end
|
||||||
|
return vcat(vec.(copy(xs))...), p -> _restructure(m, p)
|
||||||
|
end
|
||||||
|
|
||||||
# Other
|
# Other
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue