diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 3d9836d0..b8b06471 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -43,6 +43,8 @@ end Base.print_array(io::IO, x::TrackedArray) = Base.print_array(io, data(x)) +Base.copy(x::TrackedArray) = copy(data(x)) + Base.setindex!(xs::TrackedArray, v, i...) = error("Can't differentiate `setindex!`") diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index ad7b643d..e0ae7db1 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -23,6 +23,8 @@ end Base.decompose(x::TrackedReal) = Base.decompose(data(x)) +Base.copy(x::TrackedArray) = copy(data(x)) + Base.convert(::Type{T}, x::TrackedReal{S}) where {T<:Real,S} = convert(T, data(x)) Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x