fix logitcrossentropy type
This commit is contained in:
parent
65a41f2de6
commit
6a35fe07c9
@ -10,7 +10,7 @@ end
|
|||||||
|
|
||||||
@deprecate logloss(x, y) crossentropy(x, y)
|
@deprecate logloss(x, y) crossentropy(x, y)
|
||||||
|
|
||||||
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1, efftype = eltype(ŷ))
|
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1, efftype = eltype(logŷ))
|
||||||
return -sum(y .* logsoftmax(logŷ) .* weight) / convert(efftype, size(y, 2))
|
return -sum(y .* logsoftmax(logŷ) .* weight) / convert(efftype, size(y, 2))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user