cpu/gpu conveniences
This commit is contained in:
parent
15d1d3256b
commit
466b5c501a
@ -10,7 +10,7 @@ using MacroTools: @forward
|
|||||||
export Chain, Dense, RNN, LSTM, GRU, Conv, Conv2D,
|
export Chain, Dense, RNN, LSTM, GRU, Conv, Conv2D,
|
||||||
Dropout, LayerNorm, BatchNorm,
|
Dropout, LayerNorm, BatchNorm,
|
||||||
SGD, ADAM, Momentum, Nesterov, AMSGrad,
|
SGD, ADAM, Momentum, Nesterov, AMSGrad,
|
||||||
param, params, mapleaves
|
param, params, mapleaves, cpu, gpu
|
||||||
|
|
||||||
@reexport using NNlib
|
@reexport using NNlib
|
||||||
|
|
||||||
|
@ -40,3 +40,15 @@ function params(m)
|
|||||||
end
|
end
|
||||||
|
|
||||||
params(m...) = params(m)
|
params(m...) = params(m)
|
||||||
|
|
||||||
|
# CPU/GPU movement conveniences
|
||||||
|
|
||||||
|
cpu(x) = adapt(Array, x)
|
||||||
|
|
||||||
|
default_adaptor = Array
|
||||||
|
|
||||||
|
@require CuArrays begin
|
||||||
|
global default_adaptor = CuArray
|
||||||
|
end
|
||||||
|
|
||||||
|
gpu(x) = adapt(default_adaptor, x)
|
||||||
|
Loading…
Reference in New Issue
Block a user