Make the vector of weights test pass on GPU

This commit is contained in:
Katharine Hyatt 2019-10-23 09:53:09 -04:00
parent f7ce717aaa
commit 8913c9c741
1 changed files with 1 additions and 1 deletions

View File

@ -29,7 +29,7 @@ x = [1,2,3]
cx = gpu(x)
@test Flux.crossentropy(x,x) Flux.crossentropy(cx,cx)
@test Flux.crossentropy(x,x, weight=1.0) Flux.crossentropy(cx,cx, weight=1.0)
@test_broken Flux.crossentropy(x,x, weight=[1.0;2.0;3.0]) Flux.crossentropy(cx,cx, weight=[1.0;2.0;3.0])
@test Flux.crossentropy(x,x, weight=[1.0;2.0;3.0]) Flux.crossentropy(cx,cx, weight=cu([1.0;2.0;3.0]))
xs = rand(5, 5)
ys = Flux.onehotbatch(1:5,1:5)