fix tests

CarloLucibello 2020-04-30 12:11:15 +02:00
parent 654b100ce3
commit b44ba162b1
4 changed files with 24 additions and 15 deletions

View File

@@ -361,9 +361,9 @@ version = "1.2.11+9"
[[Zygote]]
deps = ["AbstractFFTs", "ArrayLayouts", "DiffRules", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
git-tree-sha1 = "08ee0b7796c4c9ce644b9ecc326f3e047486baeb"
git-tree-sha1 = "f7b0f77a86d2434abf693e3c0330e4682deed28d"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.4.17"
version = "0.4.18"
[[ZygoteRules]]
deps = ["MacroTools"]

View File

@@ -50,7 +50,8 @@ wsum(w::Number, x; dims) = w .* sum(x, dims=dims)
wsum(w::AbstractArray, x; dims) = sum( w .* x, dims=dims)
"""
    crossentropy(ŷ, y; weight=nothing, dims=1, ϵ=eps(eltype(ŷ)), agg=mean)
    crossentropy(ŷ, y; weight=nothing, dims=1, ϵ=eps(eltype(ŷ)),
                 logits=false, agg=mean)
Return the cross entropy between the given probability distributions;
calculated as
@@ -60,16 +61,22 @@ calculated as
`weight` can be `nothing`, a number or an array.
`weight=nothing` acts like `weight=1` but is faster.
If `logits=true`, the input `ŷ` is first fed to a [`softmax`](@ref) layer.
See also: [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
"""
function crossentropy(ŷ, y; dims=1, agg=mean, ϵ=epseltype(ŷ), weight=nothing)
function crossentropy(ŷ, y; dims=1, agg=mean, ϵ=epseltype(ŷ),
                      weight=nothing, logits=false)
    if logits
        return logitcrossentropy(ŷ, y; dims=dims, agg=agg, weight=weight)
    end
    agg(.-wsum(weight, y .* log.(ŷ .+ ϵ); dims=dims))
end
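Not part of the diff: a minimal sketch of how the new `logits` keyword is meant to be used, based only on the signatures shown above; the input values are made up.

    using Flux

    y      = Flux.onehotbatch([1, 2, 2], 1:3)   # one-hot targets, 3 classes × 3 samples
    scores = randn(3, 3)                        # raw model outputs (logits)
    probs  = Flux.softmax(scores)               # normalised probabilities

    # With logits=true the call forwards to logitcrossentropy, so it should
    # agree with the probability-based path up to the ϵ stabiliser.
    Flux.crossentropy(probs, y) ≈ Flux.crossentropy(scores, y, logits=true)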
"""
    logitcrossentropy(ŷ, y; weight=nothing, agg=mean, dims=1)
Return the crossentropy computed after a [`Flux.logsoftmax`](@ref) operation;
calculated as
    agg(.-sum(weight .* y .* logsoftmax(ŷ; dims=dims); dims=dims))
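Not part of the diff: a quick check of that formula against the function itself, assuming the defaults `weight=nothing` and `agg=mean`; the data are made up.

    using Flux, Statistics

    x = randn(3, 4)                            # logits, 3 classes × 4 samples
    y = Flux.onehotbatch([1, 3, 2, 2], 1:3)    # one-hot targets

    # logitcrossentropy should match the docstring's logsoftmax formulation.
    Flux.logitcrossentropy(x, y) ≈ mean(.-sum(y .* Flux.logsoftmax(x; dims=1); dims=1))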
@@ -84,15 +91,18 @@ function logitcrossentropy(ŷ, y; dims=1, agg=mean, weight=nothing)
end
"""
    binarycrossentropy(ŷ, y; agg=mean, ϵ=epseltype(ŷ))
    binarycrossentropy(ŷ, y; agg=mean, ϵ=epseltype(ŷ), logits=false)
Return ``-y*\\log(ŷ + ϵ) - (1-y)*\\log(1-ŷ + ϵ)``. The `ϵ` term provides numerical stability.
Typically, the prediction `ŷ` is given by the output of a [`sigmoid`](@ref) activation.
If `logits=true`, the input `ŷ` is first fed to a [`sigmoid`](@ref) activation.
See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
"""
function binarycrossentropy(ŷ, y; agg=mean, ϵ=epseltype(ŷ))
function binarycrossentropy(ŷ, y; agg=mean, ϵ=epseltype(ŷ), logits=false)
    if logits
        return logitbinarycrossentropy(ŷ, y; agg=agg)
    end
    agg(@.(-y*log(ŷ+ϵ) - (1-y)*log(1-ŷ+ϵ)))
end
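Not part of the diff: the analogous sketch for the binary case, again assuming only the signatures in this file; the data are made up.

    using Flux

    y = [0, 1, 1, 0]                # binary targets
    z = [-1.2, 0.3, 2.5, -0.7]      # raw scores (logits)
    p = Flux.sigmoid.(z)            # probabilities

    # logits=true forwards to logitbinarycrossentropy, which evaluates the
    # same expression in a numerically safer way.
    Flux.binarycrossentropy(p, y) ≈ Flux.binarycrossentropy(z, y, logits=true)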

View File

@@ -7,7 +7,6 @@ nfan(dims...) = prod(dims[1:end-2]) .* (dims[end-1], dims[end]) # In case of convolution kernels
ofeltype(x, y) = convert(float(eltype(x)), y)
epseltype(x) = eps(float(eltype(x)))
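Not part of the diff: for reference, what these two helpers return for a Float32 array (called here with the `Flux.` prefix since they are internal).

    Flux.ofeltype(rand(Float32, 3), 1)   # 1.0f0 (converted to the array's float eltype)
    Flux.epseltype(rand(Float32, 3))     # 1.1920929f-7 == eps(Float32)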
"""
glorot_uniform(dims...)

View File

@@ -89,18 +89,18 @@ const ϵ = 1e-7
@test Flux.poisson_loss(ŷ, y) ≈ 0.6278353988097339
@test Flux.poisson_loss(y, y) ≈ 0.5044459776946685
end
y = [1.0 0.5 0.3 2.4]
ŷ = [0 1.4 0.5 1.2]
@testset "dice_coeff_loss" begin
@test Flux.dice_coeff_loss(ŷ, y, dims=1) ≈ 0.2799999999999999
@test Flux.dice_coeff_loss(y, y, dims=1) ≈ 0.0
@test Flux.dice_coeff_loss(ŷ, y, dims=(1,2)) ≈ 0.2799999999999999
@test Flux.dice_coeff_loss(y, y, dims=(1,2)) ≈ 0.0
end
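Not part of the diff: the 0.28 expected value can be reproduced by hand from the standard smoothed Dice formulation, assuming `dice_coeff_loss` uses a smoothing constant of 1 (the constant is not shown in this diff).

    y = [1.0 0.5 0.3 2.4]; ŷ = [0 1.4 0.5 1.2]
    num = 2*sum(y .* ŷ) + 1           # 2*3.73 + 1 = 8.46
    den = sum(y.^2) + sum(ŷ.^2) + 1   # 7.1 + 3.65 + 1 = 11.75
    1 - num/den                       # ≈ 0.28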
@testset "tversky_loss" begin
@test Flux.tversky_loss(ŷ, y, dims=1) ≈ -0.06772009029345383
@test Flux.tversky_loss(ŷ, y, dims=1, β = 0.8) ≈ -0.09490740740740744
@test Flux.tversky_loss(y, y, dims=1) ≈ -0.5576923076923075
@test Flux.tversky_loss(ŷ, y, dims=(1,2)) ≈ 0.036175710594315236
@test Flux.tversky_loss(ŷ, y, dims=(1,2), β = 0.8) ≈ 0.06281407035175879
@test Flux.tversky_loss(y, y, dims=(1,2)) ≈ -0.6904761904761902
end
@testset "no spurious promotions" begin