basic GPU tests
This commit is contained in:
parent
1beb30e19a
commit
4207fb98f2
@ -26,6 +26,7 @@ end
|
|||||||
|
|
||||||
children(c::Chain) = c.layers
|
children(c::Chain) = c.layers
|
||||||
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
|
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
|
||||||
|
adapt(T, c::Chain) = Chain(map(x -> adapt(T, x), c.layers)...)
|
||||||
|
|
||||||
(c::Chain)(x) = foldl((x, m) -> m(x), x, c.layers)
|
(c::Chain)(x) = foldl((x, m) -> m(x), x, c.layers)
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
import Adapt: adapt
|
||||||
|
|
||||||
children(x) = ()
|
children(x) = ()
|
||||||
mapchildren(f, x) = x
|
mapchildren(f, x) = x
|
||||||
|
|
||||||
@ -8,6 +10,7 @@ function treelike(T, fs = fieldnames(T))
|
|||||||
@eval begin
|
@eval begin
|
||||||
children(x::$T) = ($([:(x.$f) for f in fs]...),)
|
children(x::$T) = ($([:(x.$f) for f in fs]...),)
|
||||||
mapchildren(f, x::$T) = $T(f.(children(x))...)
|
mapchildren(f, x::$T) = $T(f.(children(x))...)
|
||||||
|
adapt(T, x::$T) = mapchildren(x -> adapt(T, x), x)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
21
test/cuarrays.jl
Normal file
21
test/cuarrays.jl
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
using Flux, Flux.Tracker, CuArrays, Base.Test
|
||||||
|
|
||||||
|
@testset "CuArrays" begin
|
||||||
|
|
||||||
|
CuArrays.allowscalar(false)
|
||||||
|
|
||||||
|
x = param(randn(5, 5))
|
||||||
|
cx = cu(x)
|
||||||
|
@test cx isa TrackedArray && cx.data isa CuArray
|
||||||
|
|
||||||
|
x = Flux.onehotbatch([1, 2, 3], 1:3)
|
||||||
|
cx = cu(x)
|
||||||
|
@test cx isa Flux.OneHotMatrix && cx.data isa CuArray
|
||||||
|
|
||||||
|
m = Chain(Dense(10, 5, σ), Dense(5, 2))
|
||||||
|
cm = cu(m)
|
||||||
|
|
||||||
|
@test all(p isa TrackedArray && p.data isa CuArray for p in params(cm))
|
||||||
|
@test cm(cu(rand(10, 10))) isa TrackedArray{Float32,2,CuArray{Float32,2}}
|
||||||
|
|
||||||
|
end
|
@ -9,4 +9,8 @@ include("layers/stateless.jl")
|
|||||||
include("optimise.jl")
|
include("optimise.jl")
|
||||||
include("data.jl")
|
include("data.jl")
|
||||||
|
|
||||||
|
if Base.find_in_path("CuArrays") ≠ nothing
|
||||||
|
include("cuarrays.jl")
|
||||||
|
end
|
||||||
|
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user