diff --git a/src/tracker/lib/array.jl b/src/tracker/lib/array.jl index 52b92cf7..820e64c9 100644 --- a/src/tracker/lib/array.jl +++ b/src/tracker/lib/array.jl @@ -35,10 +35,10 @@ Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T} Base.convert(::Type{T}, x::S) where {T<:TrackedArray,S<:T} = x -Base.convert(::Type{<:TrackedArray}, x::TrackedArray) = - error("Not implemented: convert $(typeof(x)) to $T") +Base.convert(::Type{TrackedArray{T,N,A}}, x::TrackedArray) where {T,N,A} = + track(convert, A, x) -Base.convert(::Type{<:TrackedArray{T,N,A}}, x::AbstractArray) where {T,N,A} = +Base.convert(::Type{TrackedArray{T,N,A}}, x::AbstractArray) where {T,N,A} = TrackedArray(convert(A, x)) Base.show(io::IO, t::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} = diff --git a/src/tracker/lib/real.jl b/src/tracker/lib/real.jl index ec57f0d3..183392e6 100644 --- a/src/tracker/lib/real.jl +++ b/src/tracker/lib/real.jl @@ -37,7 +37,9 @@ Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x)) Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} = - error("Not implemented: convert tracked $S to tracked $T") + track(convert, T, x) + +@grad convert(T, x) = convert(T, data(x)), ȳ -> (nothing, convert(typeof(x), ȳ)) (T::Type{<:TrackedReal})(x::Real) = convert(T, x)