fix logitcrossentropy type

This commit is contained in:
Dhairya Gandhi 2019-01-03 19:45:38 +05:30
parent 65a41f2de6
commit 6a35fe07c9

View File

@ -10,7 +10,7 @@ end
@deprecate logloss(x, y) crossentropy(x, y) @deprecate logloss(x, y) crossentropy(x, y)
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1, efftype = eltype()) function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1, efftype = eltype(log))
return -sum(y .* logsoftmax(logŷ) .* weight) / convert(efftype, size(y, 2)) return -sum(y .* logsoftmax(logŷ) .* weight) / convert(efftype, size(y, 2))
end end