requires update
This commit is contained in:
parent
a49e2eae41
commit
b18b51656c
@ -37,6 +37,6 @@ include("layers/normalise.jl")
|
|||||||
|
|
||||||
include("data/Data.jl")
|
include("data/Data.jl")
|
||||||
|
|
||||||
@require CuArrays include("cuda/cuda.jl")
|
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" include("cuda/cuda.jl")
|
||||||
|
|
||||||
end # module
|
end # module
|
||||||
|
@ -32,7 +32,7 @@ import Adapt.adapt
|
|||||||
|
|
||||||
adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
||||||
|
|
||||||
@require CuArrays begin
|
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
|
||||||
import CuArrays: CuArray, cudaconvert
|
import CuArrays: CuArray, cudaconvert
|
||||||
Base.Broadcast._containertype(::Type{<:OneHotMatrix{<:CuArray}}) = CuArray
|
Base.Broadcast._containertype(::Type{<:OneHotMatrix{<:CuArray}}) = CuArray
|
||||||
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
||||||
|
@ -53,7 +53,7 @@ cpu(m) = mapleaves(x -> adapt(Array, x), m)
|
|||||||
|
|
||||||
gpu_adaptor = identity
|
gpu_adaptor = identity
|
||||||
|
|
||||||
@require CuArrays begin
|
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
|
||||||
global gpu_adaptor = CuArrays.cu
|
global gpu_adaptor = CuArrays.cu
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user