diff --git a/src/onehot.jl b/src/onehot.jl index d01dc9e1..aea68829 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -22,9 +22,12 @@ Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)] Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...]) +import NNlib.adapt + +adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)) + @require CuArrays begin import CuArrays: CuArray, cudaconvert - CuArrays.cu(xs::OneHotMatrix) = OneHotMatrix(xs.height, CuArrays.cu(xs.data)) Base.Broadcast._containertype(::Type{<:OneHotMatrix{<:CuArray}}) = CuArray cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data)) end diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 74fcb2b8..e218c3ea 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -71,11 +71,10 @@ include("back.jl") include("lib.jl") include("numeric.jl") -using Requires +import NNlib.adapt -@require CuArrays begin - import CuArrays: cu - cu(xs::TrackedArray) = TrackedArray(xs.f, cu(xs.data), RefValue(cu(grad(xs)))) -end +adapt(T, xs::TrackedArray) = + TrackedArray(xs.f, adapt(T, xs.data), + RefValue(adapt(T, grad(xs)))) end