use kwarg
This commit is contained in:
parent
128725cefd
commit
e3a688e706
@ -4,16 +4,10 @@ using NNlib: log_fast
|
|||||||
|
|
||||||
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
|
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
|
||||||
|
|
||||||
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat)
|
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
|
||||||
return -sum(y .* log_fast.(ŷ)) / size(y, 2)
|
return -sum(y .* log_fast.(ŷ) .* weight) / size(y, 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
function weighted_crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, w::AbstractVecOrMat)
|
|
||||||
return -sum(y .* log_fast.(ŷ) .* w) / size(y, 2)
|
|
||||||
end
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@deprecate logloss(x, y) crossentropy(x, y)
|
@deprecate logloss(x, y) crossentropy(x, y)
|
||||||
|
|
||||||
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat)
|
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
using Flux: onehotbatch, mse, crossentropy, weighted_crossentropy
|
using Flux: onehotbatch, mse, crossentropy
|
||||||
|
|
||||||
@testset "losses" begin
|
@testset "losses" begin
|
||||||
# First, regression-style y's
|
# First, regression-style y's
|
||||||
@ -19,8 +19,8 @@ using Flux: onehotbatch, mse, crossentropy, weighted_crossentropy
|
|||||||
end
|
end
|
||||||
|
|
||||||
@testset "weighted_crossentropy" begin
|
@testset "weighted_crossentropy" begin
|
||||||
@test weighted_crossentropy(y_hat, y, ones(2)) ≈ y_logloss
|
@test crossentropy(y_hat, y, weight = ones(2)) ≈ y_logloss
|
||||||
@test weighted_crossentropy(y_hat, y, [.5, .5]) ≈ y_logloss/2
|
@test crossentropy(y_hat, y, weight = [.5, .5]) ≈ y_logloss/2
|
||||||
@test weighted_crossentropy(y_hat, y, [2, .5]) ≈ 1.5049660054074199
|
@test crossentropy(y_hat, y, weight = [2, .5]) ≈ 1.5049660054074199
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user