diff --git a/src/model.jl b/src/model.jl index ca52154e..d333d9b2 100644 --- a/src/model.jl +++ b/src/model.jl @@ -48,6 +48,7 @@ graph(m) = nothing # Model parameters +# TODO: should be AbstractArray? """ A `Param` object stores a parameter array along with an accumulated delta to that array. When converting to backends like TensorFlow, identical `Param`s will @@ -99,6 +100,9 @@ function Base.show(io::IO, p::Param) print(io, "Param", size(p.x)) end +Base.copy!(xs, p::Param) = copy!(xs, p.x) +Base.copy!(p::Param, xs) = copy!(p.x, xs) + # Anonymous models export Capacitor