basic GPU tests
This commit is contained in:
parent
1beb30e19a
commit
4207fb98f2
|
@ -26,6 +26,7 @@ end
|
|||
|
||||
children(c::Chain) = c.layers
|
||||
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
|
||||
adapt(T, c::Chain) = Chain(map(x -> adapt(T, x), c.layers)...)
|
||||
|
||||
(c::Chain)(x) = foldl((x, m) -> m(x), x, c.layers)
|
||||
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import Adapt: adapt
|
||||
|
||||
children(x) = ()
|
||||
mapchildren(f, x) = x
|
||||
|
||||
|
@ -8,6 +10,7 @@ function treelike(T, fs = fieldnames(T))
|
|||
@eval begin
|
||||
children(x::$T) = ($([:(x.$f) for f in fs]...),)
|
||||
mapchildren(f, x::$T) = $T(f.(children(x))...)
|
||||
adapt(T, x::$T) = mapchildren(x -> adapt(T, x), x)
|
||||
end
|
||||
end
|
||||
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
using Flux, Flux.Tracker, CuArrays, Base.Test
|
||||
|
||||
@testset "CuArrays" begin
|
||||
|
||||
CuArrays.allowscalar(false)
|
||||
|
||||
x = param(randn(5, 5))
|
||||
cx = cu(x)
|
||||
@test cx isa TrackedArray && cx.data isa CuArray
|
||||
|
||||
x = Flux.onehotbatch([1, 2, 3], 1:3)
|
||||
cx = cu(x)
|
||||
@test cx isa Flux.OneHotMatrix && cx.data isa CuArray
|
||||
|
||||
m = Chain(Dense(10, 5, σ), Dense(5, 2))
|
||||
cm = cu(m)
|
||||
|
||||
@test all(p isa TrackedArray && p.data isa CuArray for p in params(cm))
|
||||
@test cm(cu(rand(10, 10))) isa TrackedArray{Float32,2,CuArray{Float32,2}}
|
||||
|
||||
end
|
|
@ -9,4 +9,8 @@ include("layers/stateless.jl")
|
|||
include("optimise.jl")
|
||||
include("data.jl")
|
||||
|
||||
if Base.find_in_path("CuArrays") ≠ nothing
|
||||
include("cuarrays.jl")
|
||||
end
|
||||
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue