Compare commits

...

4 Commits

Author SHA1 Message Date
Dhairya Gandhi 6a35fe07c9 fix logitcrossentropy type 2019-01-03 19:45:38 +05:30
Dhairya Gandhi 65a41f2de6 use explicit converts 2019-01-03 19:31:56 +05:30
Dhairya Gandhi d54b0e312a correct convert dispatch 2019-01-03 11:42:48 +05:30
Dhairya Gandhi 9eaf26d1d7 type fixes 2019-01-03 11:42:48 +05:30
1 changed files with 5 additions and 5 deletions

View File

@ -2,16 +2,16 @@ using NNlib: logsoftmax, logσ
# Cost functions
mse(, y) = sum(( .- y).^2)/length(y)
mse(, y; efftype = eltype()) = sum(( .- y).^2)/convert(efftype, length(y))
function crossentropy(::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
-sum(y .* log.() .* weight) / size(y, 2)
function crossentropy(::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1, efftype = eltype())
-sum(y .* log.() .* weight) / convert(efftype, size(y, 2))
end
@deprecate logloss(x, y) crossentropy(x, y)
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
return -sum(y .* logsoftmax(logŷ) .* weight) / size(y, 2)
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1, efftype = eltype(logŷ))
return -sum(y .* logsoftmax(logŷ) .* weight) / convert(efftype, size(y, 2))
end
"""