harder test
This commit is contained in:
parent
8019f789f8
commit
cf7dd34767
@ -15,7 +15,7 @@ x = Flux.onehotbatch([1, 2, 3], 1:3)
|
|||||||
cx = gpu(x)
|
cx = gpu(x)
|
||||||
@test cx isa Flux.OneHotMatrix && cx.data isa CuArray
|
@test cx isa Flux.OneHotMatrix && cx.data isa CuArray
|
||||||
|
|
||||||
m = Chain(Dense(10, 5, σ), Dense(5, 2))
|
m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
|
||||||
cm = gpu(m)
|
cm = gpu(m)
|
||||||
|
|
||||||
@test all(p isa TrackedArray && p.data isa CuArray for p in params(cm))
|
@test all(p isa TrackedArray && p.data isa CuArray for p in params(cm))
|
||||||
|
Loading…
Reference in New Issue
Block a user