requires update
This commit is contained in:
parent
a49e2eae41
commit
b18b51656c
@ -37,6 +37,6 @@ include("layers/normalise.jl")
|
||||
|
||||
include("data/Data.jl")
|
||||
|
||||
@require CuArrays include("cuda/cuda.jl")
|
||||
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" include("cuda/cuda.jl")
|
||||
|
||||
end # module
|
||||
|
@ -32,7 +32,7 @@ import Adapt.adapt
|
||||
|
||||
adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
||||
|
||||
@require CuArrays begin
|
||||
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
|
||||
import CuArrays: CuArray, cudaconvert
|
||||
Base.Broadcast._containertype(::Type{<:OneHotMatrix{<:CuArray}}) = CuArray
|
||||
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
||||
|
@ -53,7 +53,7 @@ cpu(m) = mapleaves(x -> adapt(Array, x), m)
|
||||
|
||||
gpu_adaptor = identity
|
||||
|
||||
@require CuArrays begin
|
||||
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
|
||||
global gpu_adaptor = CuArrays.cu
|
||||
end
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user