update gpu api

This commit is contained in:
Mike J Innes 2018-02-28 22:51:08 +00:00
parent ccef9f4dd4
commit 2eb38eedbf
3 changed files with 13 additions and 13 deletions

View File

@ -10,7 +10,6 @@ function treelike(T, fs = fieldnames(T))
@eval current_module() begin
children(x::$T) = ($([:(x.$f) for f in fs]...),)
mapchildren(f, x::$T) = $T(f.(children(x))...)
adapt(T, x::$T) = mapleaves(x -> adapt(T, x), x)
end
end
@ -43,12 +42,12 @@ params(m...) = params(m)
# CPU/GPU movement conveniences
cpu(x) = adapt(Array, x)
cpu(m) = mapleaves(x -> adapt(Array, x), m)
default_adaptor = identity
gpu_adaptor = identity
@require CuArrays begin
global default_adaptor = CuArrays.cu
global gpu_adaptor = CuArrays.cu
end
gpu(x) = adapt(default_adaptor, x)
gpu(x) = mapleaves(gpu_adaptor, x)

View File

@ -1,4 +1,5 @@
using Flux, Flux.Tracker, CuArrays, Base.Test
using Flux: gpu
info("Testing Flux/GPU")
@ -7,18 +8,18 @@ info("Testing Flux/GPU")
CuArrays.allowscalar(false)
x = param(randn(5, 5))
cx = cu(x)
cx = gpu(x)
@test cx isa TrackedArray && cx.data isa CuArray
x = Flux.onehotbatch([1, 2, 3], 1:3)
cx = cu(x)
cx = gpu(x)
@test cx isa Flux.OneHotMatrix && cx.data isa CuArray
m = Chain(Dense(10, 5, σ), Dense(5, 2))
cm = cu(m)
cm = gpu(m)
@test all(p isa TrackedArray && p.data isa CuArray for p in params(cm))
@test cm(cu(rand(10, 10))) isa TrackedArray{Float32,2,CuArray{Float32,2}}
@test cm(gpu(rand(10, 10))) isa TrackedArray{Float32,2,CuArray{Float32,2}}
end

View File

@ -5,14 +5,14 @@ info("Testing Flux/CUDNN")
@testset "RNN" begin
@testset for R in [RNN, GRU, LSTM]
rnn = R(10, 5)
curnn = mapleaves(cu, rnn)
curnn = mapleaves(gpu, rnn)
@testset for batch_size in (1, 5)
Flux.reset!(rnn)
Flux.reset!(curnn)
x = batch_size == 1 ?
param(rand(10)) :
param(rand(10,batch_size))
cux = cu(x)
cux = gpu(x)
y = (rnn(x); rnn(x))
cuy = (curnn(cux); curnn(cux))
@ -22,7 +22,7 @@ info("Testing Flux/CUDNN")
Δ = randn(size(y))
Flux.back!(y, Δ)
Flux.back!(cuy, cu(Δ))
Flux.back!(cuy, gpu(Δ))
@test x.grad collect(cux.grad)
@test rnn.cell.Wi.grad collect(curnn.cell.Wi.grad)
@ -38,7 +38,7 @@ info("Testing Flux/CUDNN")
ohx = batch_size == 1 ?
Flux.onehot(rand(1:10), 1:10) :
Flux.onehotbatch(rand(1:10, batch_size), 1:10)
cuohx = cu(ohx)
cuohx = gpu(ohx)
y = (rnn(ohx); rnn(ohx))
cuy = (curnn(cuohx); curnn(cuohx))