diff --git a/src/treelike.jl b/src/treelike.jl index cc7d827a..238eb02c 100644 --- a/src/treelike.jl +++ b/src/treelike.jl @@ -45,10 +45,10 @@ params(m...) = params(m) cpu(x) = adapt(Array, x) -default_adaptor = Array +default_adaptor = identity @require CuArrays begin - global default_adaptor = CuArray + global default_adaptor = CuArrays.cu end gpu(x) = adapt(default_adaptor, x)