update gpu api
This commit is contained in:
parent
ccef9f4dd4
commit
2eb38eedbf
|
@ -10,7 +10,6 @@ function treelike(T, fs = fieldnames(T))
|
|||
@eval current_module() begin
|
||||
children(x::$T) = ($([:(x.$f) for f in fs]...),)
|
||||
mapchildren(f, x::$T) = $T(f.(children(x))...)
|
||||
adapt(T, x::$T) = mapleaves(x -> adapt(T, x), x)
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -43,12 +42,12 @@ params(m...) = params(m)
|
|||
|
||||
# CPU/GPU movement conveniences
|
||||
|
||||
cpu(x) = adapt(Array, x)
|
||||
cpu(m) = mapleaves(x -> adapt(Array, x), m)
|
||||
|
||||
default_adaptor = identity
|
||||
gpu_adaptor = identity
|
||||
|
||||
@require CuArrays begin
|
||||
global default_adaptor = CuArrays.cu
|
||||
global gpu_adaptor = CuArrays.cu
|
||||
end
|
||||
|
||||
gpu(x) = adapt(default_adaptor, x)
|
||||
gpu(x) = mapleaves(gpu_adaptor, x)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
using Flux, Flux.Tracker, CuArrays, Base.Test
|
||||
using Flux: gpu
|
||||
|
||||
info("Testing Flux/GPU")
|
||||
|
||||
|
@ -7,18 +8,18 @@ info("Testing Flux/GPU")
|
|||
CuArrays.allowscalar(false)
|
||||
|
||||
x = param(randn(5, 5))
|
||||
cx = cu(x)
|
||||
cx = gpu(x)
|
||||
@test cx isa TrackedArray && cx.data isa CuArray
|
||||
|
||||
x = Flux.onehotbatch([1, 2, 3], 1:3)
|
||||
cx = cu(x)
|
||||
cx = gpu(x)
|
||||
@test cx isa Flux.OneHotMatrix && cx.data isa CuArray
|
||||
|
||||
m = Chain(Dense(10, 5, σ), Dense(5, 2))
|
||||
cm = cu(m)
|
||||
cm = gpu(m)
|
||||
|
||||
@test all(p isa TrackedArray && p.data isa CuArray for p in params(cm))
|
||||
@test cm(cu(rand(10, 10))) isa TrackedArray{Float32,2,CuArray{Float32,2}}
|
||||
@test cm(gpu(rand(10, 10))) isa TrackedArray{Float32,2,CuArray{Float32,2}}
|
||||
|
||||
end
|
||||
|
||||
|
|
|
@ -5,14 +5,14 @@ info("Testing Flux/CUDNN")
|
|||
@testset "RNN" begin
|
||||
@testset for R in [RNN, GRU, LSTM]
|
||||
rnn = R(10, 5)
|
||||
curnn = mapleaves(cu, rnn)
|
||||
curnn = mapleaves(gpu, rnn)
|
||||
@testset for batch_size in (1, 5)
|
||||
Flux.reset!(rnn)
|
||||
Flux.reset!(curnn)
|
||||
x = batch_size == 1 ?
|
||||
param(rand(10)) :
|
||||
param(rand(10,batch_size))
|
||||
cux = cu(x)
|
||||
cux = gpu(x)
|
||||
y = (rnn(x); rnn(x))
|
||||
cuy = (curnn(cux); curnn(cux))
|
||||
|
||||
|
@ -22,7 +22,7 @@ info("Testing Flux/CUDNN")
|
|||
Δ = randn(size(y))
|
||||
|
||||
Flux.back!(y, Δ)
|
||||
Flux.back!(cuy, cu(Δ))
|
||||
Flux.back!(cuy, gpu(Δ))
|
||||
|
||||
@test x.grad ≈ collect(cux.grad)
|
||||
@test rnn.cell.Wi.grad ≈ collect(curnn.cell.Wi.grad)
|
||||
|
@ -38,7 +38,7 @@ info("Testing Flux/CUDNN")
|
|||
ohx = batch_size == 1 ?
|
||||
Flux.onehot(rand(1:10), 1:10) :
|
||||
Flux.onehotbatch(rand(1:10, batch_size), 1:10)
|
||||
cuohx = cu(ohx)
|
||||
cuohx = gpu(ohx)
|
||||
y = (rnn(ohx); rnn(ohx))
|
||||
cuy = (curnn(cuohx); curnn(cuohx))
|
||||
|
||||
|
|
Loading…
Reference in New Issue