restructure; closes #747
This commit is contained in:
parent
e92da0cf85
commit
17732e7023
|
@ -371,9 +371,11 @@ version = "0.8.3"
|
|||
|
||||
[[Zygote]]
|
||||
deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
|
||||
git-tree-sha1 = "e4245b9c5362346e154b62842a89a18e0210b92b"
|
||||
git-tree-sha1 = "04384d940b67d604dd393688fa60c1f0175e5faf"
|
||||
repo-rev = "buffer-push"
|
||||
repo-url = "https://github.com/FluxML/Zygote.jl.git"
|
||||
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
|
||||
version = "0.4.1"
|
||||
version = "0.4.3"
|
||||
|
||||
[[ZygoteRules]]
|
||||
deps = ["MacroTools"]
|
||||
|
|
42
src/utils.jl
42
src/utils.jl
|
@ -103,6 +103,48 @@ function batchseq(xs, pad = nothing, n = maximum(length(x) for x in xs))
|
|||
[batch([xs_[j][i] for j = 1:length(xs_)]) for i = 1:n]
|
||||
end
|
||||
|
||||
# Flattening models to weight vectors, and back
|
||||
|
||||
function _restructure(m, xs)
|
||||
i = 0
|
||||
fmap(m) do x
|
||||
x isa AbstractArray || return x
|
||||
x = reshape(xs[i.+(1:length(x))], size(x))
|
||||
i += length(x)
|
||||
return x
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
destructure(m)
|
||||
|
||||
Flatten a model's parameters into a single weight vector.
|
||||
|
||||
julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
|
||||
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
|
||||
|
||||
julia> θ, re = destructure(m);
|
||||
|
||||
julia> θ
|
||||
67-element Array{Float32,1}:
|
||||
-0.1407104
|
||||
...
|
||||
|
||||
The second return value `re` allows you to reconstruct the original network after making
|
||||
modifications to the weight vector (for example, with a hypernetwork).
|
||||
|
||||
julia> re(θ .* 2)
|
||||
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
|
||||
"""
|
||||
function destructure(m)
|
||||
xs = Zygote.Buffer([])
|
||||
fmap(m) do x
|
||||
x isa AbstractArray && push!(xs, x)
|
||||
return x
|
||||
end
|
||||
return vcat(vec.(copy(xs))...), p -> _restructure(m, p)
|
||||
end
|
||||
|
||||
# Other
|
||||
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue