use Adapt

This commit is contained in:
Mike J Innes 2018-01-08 16:31:23 +00:00
parent 9cfa459516
commit 468f641f66
3 changed files with 3 additions and 2 deletions

View File

@ -5,3 +5,4 @@ MacroTools 0.3.3
NNlib
ForwardDiff 0.5.0
Requires
Adapt

View File

@ -28,7 +28,7 @@ Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs
batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)
import NNlib.adapt
import Adapt.adapt
adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))

View File

@ -93,7 +93,7 @@ include("back.jl")
include("lib.jl")
include("numeric.jl")
import NNlib.adapt
import Adapt.adapt
adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad))