From 1abc4febe67ff04b7f3decc77a8fa207189f8026 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 4 Oct 2017 18:55:56 +0100 Subject: [PATCH] more general adaptors --- src/onehot.jl | 5 ++++- src/tracker/Tracker.jl | 9 ++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/onehot.jl b/src/onehot.jl index d01dc9e1..aea68829 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -22,9 +22,12 @@ Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)] Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...]) +import NNlib.adapt + +adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)) + @require CuArrays begin import CuArrays: CuArray, cudaconvert - CuArrays.cu(xs::OneHotMatrix) = OneHotMatrix(xs.height, CuArrays.cu(xs.data)) Base.Broadcast._containertype(::Type{<:OneHotMatrix{<:CuArray}}) = CuArray cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data)) end diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 74fcb2b8..e218c3ea 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -71,11 +71,10 @@ include("back.jl") include("lib.jl") include("numeric.jl") -using Requires +import NNlib.adapt -@require CuArrays begin - import CuArrays: cu - cu(xs::TrackedArray) = TrackedArray(xs.f, cu(xs.data), RefValue(cu(grad(xs)))) -end +adapt(T, xs::TrackedArray) = + TrackedArray(xs.f, adapt(T, xs.data), + RefValue(adapt(T, grad(xs)))) end