use Adapt
This commit is contained in:
parent
9cfa459516
commit
468f641f66
1
REQUIRE
1
REQUIRE
@ -5,3 +5,4 @@ MacroTools 0.3.3
|
|||||||
NNlib
|
NNlib
|
||||||
ForwardDiff 0.5.0
|
ForwardDiff 0.5.0
|
||||||
Requires
|
Requires
|
||||||
|
Adapt
|
||||||
|
@ -28,7 +28,7 @@ Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs
|
|||||||
|
|
||||||
batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)
|
batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)
|
||||||
|
|
||||||
import NNlib.adapt
|
import Adapt.adapt
|
||||||
|
|
||||||
adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
||||||
|
|
||||||
|
@ -93,7 +93,7 @@ include("back.jl")
|
|||||||
include("lib.jl")
|
include("lib.jl")
|
||||||
include("numeric.jl")
|
include("numeric.jl")
|
||||||
|
|
||||||
import NNlib.adapt
|
import Adapt.adapt
|
||||||
|
|
||||||
adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad))
|
adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user