logsoftmax tests
This commit is contained in:
parent
a4bf5936b0
commit
f9be72f545
|
@ -17,10 +17,8 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||
|
||||
@test gradtest(x -> softmax(x).*(1:3), 3)
|
||||
@test gradtest(x -> softmax(x).*(1:3), (3,5))
|
||||
|
||||
## uncomment the following test when logsoftmax has been added into NNlib.jl
|
||||
#@test gradtest(x -> logsoftmax(x).*(1:3), 3)
|
||||
#@test gradtest(x -> logsoftmax(x).*(1:3), (3,5))
|
||||
@test gradtest(x -> logsoftmax(x).*(1:3), 3)
|
||||
@test gradtest(x -> logsoftmax(x).*(1:3), (3,5))
|
||||
|
||||
@test gradtest(Flux.mse, rand(5,5), rand(5, 5))
|
||||
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))
|
||||
|
|
Loading…
Reference in New Issue