more general adaptors
This commit is contained in:
parent
2b95aff158
commit
1abc4febe6
@ -22,9 +22,12 @@ Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)]
|
|||||||
|
|
||||||
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...])
|
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...])
|
||||||
|
|
||||||
|
import NNlib.adapt
|
||||||
|
|
||||||
|
adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
||||||
|
|
||||||
@require CuArrays begin
|
@require CuArrays begin
|
||||||
import CuArrays: CuArray, cudaconvert
|
import CuArrays: CuArray, cudaconvert
|
||||||
CuArrays.cu(xs::OneHotMatrix) = OneHotMatrix(xs.height, CuArrays.cu(xs.data))
|
|
||||||
Base.Broadcast._containertype(::Type{<:OneHotMatrix{<:CuArray}}) = CuArray
|
Base.Broadcast._containertype(::Type{<:OneHotMatrix{<:CuArray}}) = CuArray
|
||||||
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
||||||
end
|
end
|
||||||
|
@ -71,11 +71,10 @@ include("back.jl")
|
|||||||
include("lib.jl")
|
include("lib.jl")
|
||||||
include("numeric.jl")
|
include("numeric.jl")
|
||||||
|
|
||||||
using Requires
|
import NNlib.adapt
|
||||||
|
|
||||||
@require CuArrays begin
|
adapt(T, xs::TrackedArray) =
|
||||||
import CuArrays: cu
|
TrackedArray(xs.f, adapt(T, xs.data),
|
||||||
cu(xs::TrackedArray) = TrackedArray(xs.f, cu(xs.data), RefValue(cu(grad(xs))))
|
RefValue(adapt(T, grad(xs))))
|
||||||
end
|
|
||||||
|
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user