use kwarg

This commit is contained in:
Mike J Innes 2017-12-13 15:27:15 +00:00
parent 128725cefd
commit e3a688e706
2 changed files with 6 additions and 12 deletions

View File

@@ -4,16 +4,10 @@ using NNlib: log_fast
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat)
return -sum(y .* log_fast.(ŷ)) / size(y, 2)
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
return -sum(y .* log_fast.(ŷ) .* weight) / size(y, 2)
end
function weighted_crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, w::AbstractVecOrMat)
return -sum(y .* log_fast.(ŷ) .* w) / size(y, 2)
end
@deprecate logloss(x, y) crossentropy(x, y)
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat)

View File

@@ -1,4 +1,4 @@
using Flux: onehotbatch, mse, crossentropy, weighted_crossentropy
using Flux: onehotbatch, mse, crossentropy
@testset "losses" begin
# First, regression-style y's
@@ -19,8 +19,8 @@ using Flux: onehotbatch, mse, crossentropy, weighted_crossentropy
end
@testset "weighted_crossentropy" begin
@test weighted_crossentropy(y_hat, y, ones(2)) ≈ y_logloss
@test weighted_crossentropy(y_hat, y, [.5, .5]) ≈ y_logloss/2
@test weighted_crossentropy(y_hat, y, [2, .5]) ≈ 1.5049660054074199
@test crossentropy(y_hat, y, weight = ones(2)) ≈ y_logloss
@test crossentropy(y_hat, y, weight = [.5, .5]) ≈ y_logloss/2
@test crossentropy(y_hat, y, weight = [2, .5]) ≈ 1.5049660054074199
end
end