cpu/gpu conveniences
This commit is contained in:
parent
15d1d3256b
commit
466b5c501a
@ -10,7 +10,7 @@ using MacroTools: @forward
|
||||
export Chain, Dense, RNN, LSTM, GRU, Conv, Conv2D,
|
||||
Dropout, LayerNorm, BatchNorm,
|
||||
SGD, ADAM, Momentum, Nesterov, AMSGrad,
|
||||
param, params, mapleaves
|
||||
param, params, mapleaves, cpu, gpu
|
||||
|
||||
@reexport using NNlib
|
||||
|
||||
|
@ -40,3 +40,15 @@ function params(m)
|
||||
end
|
||||
|
||||
params(m...) = params(m)
|
||||
|
||||
# CPU/GPU movement conveniences
|
||||
|
||||
cpu(x) = adapt(Array, x)
|
||||
|
||||
default_adaptor = Array
|
||||
|
||||
@require CuArrays begin
|
||||
global default_adaptor = CuArray
|
||||
end
|
||||
|
||||
gpu(x) = adapt(default_adaptor, x)
|
||||
|
Loading…
Reference in New Issue
Block a user