diff --git a/src/treelike.jl b/src/treelike.jl index 4d896771..cac90e4e 100644 --- a/src/treelike.jl +++ b/src/treelike.jl @@ -40,6 +40,14 @@ end params(m...) = params(m) +function loadparams!(m, xs) + for (p, x) in zip(params(m), xs) + size(p) == size(x) || + error("Expected param size $(size(p)), got $(size(x))") + copy!(data(p), data(x)) + end +end + # CPU/GPU movement conveniences cpu(m) = mapleaves(x -> adapt(Array, x), m)