basic GPU tests

This commit is contained in:
Mike J Innes 2018-01-16 17:58:14 +00:00
parent 1beb30e19a
commit 4207fb98f2
4 changed files with 29 additions and 0 deletions

View File

@ -26,6 +26,7 @@ end
children(c::Chain) = c.layers
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
adapt(T, c::Chain) = Chain(map(x -> adapt(T, x), c.layers)...)
(c::Chain)(x) = foldl((x, m) -> m(x), x, c.layers)

View File

@ -1,3 +1,5 @@
import Adapt: adapt
children(x) = ()
mapchildren(f, x) = x
@ -8,6 +10,7 @@ function treelike(T, fs = fieldnames(T))
@eval begin
children(x::$T) = ($([:(x.$f) for f in fs]...),)
mapchildren(f, x::$T) = $T(f.(children(x))...)
adapt(T, x::$T) = mapchildren(x -> adapt(T, x), x)
end
end

21
test/cuarrays.jl Normal file
View File

@ -0,0 +1,21 @@
using Flux, Flux.Tracker, CuArrays, Base.Test
@testset "CuArrays" begin
CuArrays.allowscalar(false)
x = param(randn(5, 5))
cx = cu(x)
@test cx isa TrackedArray && cx.data isa CuArray
x = Flux.onehotbatch([1, 2, 3], 1:3)
cx = cu(x)
@test cx isa Flux.OneHotMatrix && cx.data isa CuArray
m = Chain(Dense(10, 5, σ), Dense(5, 2))
cm = cu(m)
@test all(p isa TrackedArray && p.data isa CuArray for p in params(cm))
@test cm(cu(rand(10, 10))) isa TrackedArray{Float32,2,CuArray{Float32,2}}
end

View File

@ -9,4 +9,8 @@ include("layers/stateless.jl")
include("optimise.jl")
include("data.jl")
if Base.find_in_path("CuArrays") nothing
include("cuarrays.jl")
end
end