diff --git a/docs/make.jl b/docs/make.jl index 2f24a022..b4e1b8b0 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -8,8 +8,9 @@ makedocs(modules=[Flux, NNlib], "Building Models" => ["Basics" => "models/basics.md", "Recurrence" => "models/recurrence.md", - "Regularisation" => "models/regularisation.md", "Model Reference" => "models/layers.md", + "Loss Functions" => "models/losses.md", + "Regularisation" => "models/regularisation.md", "Advanced Model Building" => "models/advanced.md", "NNlib" => "models/nnlib.md"], "Handling Data" => diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index 54ce5791..12805181 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -67,22 +67,4 @@ Many normalisation layers behave differently under training and inference (testi ```@docs Flux.testmode! trainmode! -``` - -## Cost Functions -```@docs -Flux.mae -Flux.mse -Flux.msle -Flux.huber_loss -Flux.crossentropy -Flux.logitcrossentropy -Flux.binarycrossentropy -Flux.logitbinarycrossentropy -Flux.kldivergence -Flux.poisson -Flux.hinge -Flux.squared_hinge -Flux.dice_coeff_loss -Flux.tversky_loss -``` +``` \ No newline at end of file diff --git a/docs/src/models/losses.md b/docs/src/models/losses.md new file mode 100644 index 00000000..fbf44b2b --- /dev/null +++ b/docs/src/models/losses.md @@ -0,0 +1,25 @@ +## Loss Functions +Flux provides a large number of common loss functions used for training machine learning models. + +Most loss functions in Flux have an optional argument `agg`, denoting the type of aggregation performed over the +batch: +```julia +loss(ŷ, y; agg=mean) +``` + +```@docs +Flux.mae +Flux.mse +Flux.msle +Flux.huber_loss +Flux.crossentropy +Flux.logitcrossentropy +Flux.binarycrossentropy +Flux.logitbinarycrossentropy +Flux.kldivergence +Flux.poisson_loss +Flux.hinge +Flux.squared_hinge +Flux.dice_coeff_loss +Flux.tversky_loss +``` \ No newline at end of file diff --git a/docs/src/models/regularisation.md b/docs/src/models/regularisation.md index 535dd096..ee4350f0 100644 --- a/docs/src/models/regularisation.md +++ b/docs/src/models/regularisation.md @@ -7,9 +7,10 @@ add the result to the overall loss. For example, say we have a simple regression. ```julia -using Flux: crossentropy +using Flux +using Flux: logitcrossentropy m = Dense(10, 5) -loss(x, y) = crossentropy(softmax(m(x)), y) +loss(x, y) = logitcrossentropy(m(x), y) ``` We can regularise this by taking the (L2) norm of the parameters, `m.W` and `m.b`. @@ -18,19 +19,19 @@ We can regularise this by taking the (L2) norm of the parameters, `m.W` and `m.b using LinearAlgebra penalty() = norm(m.W) + norm(m.b) -loss(x, y) = crossentropy(softmax(m(x)), y) + penalty() +loss(x, y) = logitcrossentropy(m(x), y) + penalty() ``` When working with layers, Flux provides the `params` function to grab all -parameters at once. We can easily penalise everything with `sum(norm, params)`. +parameters at once. We can easily penalise everything with `sum`: ```julia -julia> params(m) +julia> Flux.params(m) 2-element Array{Any,1}: param([0.355408 0.533092; … 0.430459 0.171498]) param([0.0, 0.0, 0.0, 0.0, 0.0]) -julia> sum(norm, params(m)) +julia> sum(norm, Flux.params(m)) 26.01749952921026 ``` @@ -40,9 +41,9 @@ Here's a larger example with a multi-layer perceptron. 
 m = Chain(
   Dense(28^2, 128, relu),
   Dense(128, 32, relu),
-  Dense(32, 10), softmax)
+  Dense(32, 10))

-loss(x, y) = crossentropy(m(x), y) + sum(norm, params(m))
+loss(x, y) = logitcrossentropy(m(x), y) + sum(norm, Flux.params(m))

 loss(rand(28^2), rand(10))
 ```
diff --git a/src/Flux.jl b/src/Flux.jl
index 5799fe42..3c770ac4 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -31,6 +31,7 @@ include("onehot.jl")
 include("functor.jl")

 include("layers/stateless.jl")
+include("layers/losses.jl")
 include("layers/basic.jl")
 include("layers/conv.jl")
 include("layers/recurrent.jl")
diff --git a/src/layers/losses.jl b/src/layers/losses.jl
new file mode 100644
index 00000000..35900b51
--- /dev/null
+++ b/src/layers/losses.jl
@@ -0,0 +1,190 @@
+# Cost functions
+"""
+    mae(ŷ, y; agg=mean)
+
+Return the loss corresponding to mean absolute error:
+
+    agg(abs.(ŷ .- y))
+"""
+mae(ŷ, y; agg=mean) = agg(abs.(ŷ .- y))
+
+"""
+    mse(ŷ, y; agg=mean)
+
+Return the loss corresponding to mean squared error:
+
+    agg((ŷ .- y).^2)
+"""
+mse(ŷ, y; agg=mean) = agg((ŷ .- y).^2)
+
+"""
+    msle(ŷ, y; agg=mean, ϵ=eps(eltype(ŷ)))
+
+The loss corresponding to mean squared logarithmic errors, calculated as
+
+    agg((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2)
+
+The `ϵ` term provides numerical stability.
+Penalizes an under-predicted estimate more than an over-predicted estimate.
+"""
+msle(ŷ, y; agg=mean, ϵ=eps(eltype(ŷ))) = agg((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2)
+
+
+"""
+    huber_loss(ŷ, y; δ=1, agg=mean)
+
+Return the [Huber loss](https://en.wikipedia.org/wiki/Huber_loss) given the
+prediction `ŷ` and true values `y`, aggregated with `agg` (the mean by default).
+
+                 | 0.5 * |ŷ - y|^2,          for |ŷ - y| <= δ
+    Huber loss = |
+                 |  δ * (|ŷ - y| - 0.5 * δ), otherwise
+"""
+function huber_loss(ŷ, y; agg=mean, δ=ofeltype(ŷ, 1))
+    abs_error = abs.(ŷ .- y)
+    temp = abs_error .< δ
+    x = ofeltype(ŷ, 0.5)
+    agg(((abs_error.^2) .* temp) .* x .+ δ*(abs_error .- x*δ) .* (1 .- temp))
+end
+
+wsum(w::Nothing, x; dims) = sum(x, dims=dims)
+wsum(w::Number, x; dims) = w .* sum(x, dims=dims)
+wsum(w::AbstractArray, x; dims) = sum(w .* x, dims=dims)
+
+"""
+    crossentropy(ŷ, y; weight=nothing, dims=1, ϵ=eps(eltype(ŷ)), agg=mean)
+
+Return the cross entropy between the given probability distributions;
+calculated as
+
+    agg(.-sum(weight .* y .* log.(ŷ .+ ϵ); dims=dims))
+
+`weight` can be `nothing`, a number or an array.
+`weight=nothing` acts like `weight=1` but is faster.
+
+See also: [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
+"""
+function crossentropy(ŷ, y; dims=1, agg=mean, ϵ=eps(eltype(ŷ)), weight=nothing)
+    agg(.-wsum(weight, y .* log.(ŷ .+ ϵ); dims=dims))
+end
+
+"""
+    logitcrossentropy(ŷ, y; weight=nothing, agg=mean, dims=1)
+
+Return the crossentropy computed after a [`Flux.logsoftmax`](@ref) operation;
+calculated as
+
+    agg(.-sum(weight .* y .* logsoftmax(ŷ; dims=dims); dims=dims))
+
+`logitcrossentropy(ŷ, y)` is mathematically equivalent to
+[`Flux.crossentropy(softmax(ŷ), y)`](@ref) but it is more numerically stable.
+
+See also: [`Flux.crossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
+"""
+function logitcrossentropy(ŷ, y; dims=1, agg=mean, weight=nothing)
+    agg(.-wsum(weight, y .* logsoftmax(ŷ; dims=dims); dims=dims))
+end
+
+"""
+    binarycrossentropy(ŷ, y; agg=mean, ϵ=eps(eltype(ŷ)))
+
+Return the binary cross-entropy loss, computed elementwise as
+``-y*\\log(ŷ + ϵ) - (1-y)*\\log(1-ŷ + ϵ)`` and aggregated with `agg`.
+The `ϵ` term provides numerical stability.
+
+Typically, the prediction `ŷ` is given by the output of a [`sigmoid`](@ref) activation.
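+
+For example (an illustrative call; with the default `agg=mean` the three per-element
+losses are averaged):
+
+    binarycrossentropy([0.1, 0.2, 0.9], [0, 0, 1])  # ≈ 0.145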
+
+See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
+"""
+function binarycrossentropy(ŷ, y; agg=mean, ϵ=eps(eltype(ŷ)))
+    agg(@.(-y*log(ŷ+ϵ) - (1-y)*log(1-ŷ+ϵ)))
+end
+
+# Re-definition to fix interaction with CuArrays.
+# CuArrays.@cufunc binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
+
+"""
+    logitbinarycrossentropy(ŷ, y; agg=mean)
+
+`logitbinarycrossentropy(ŷ, y)` is mathematically equivalent to
+[`Flux.binarycrossentropy(σ(ŷ), y)`](@ref) but it is more numerically stable.
+
+See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref)
+"""
+function logitbinarycrossentropy(ŷ, y; agg=mean)
+    agg(@.((1-y)*ŷ - logsigmoid(ŷ)))
+end
+# Re-definition to fix interaction with CuArrays.
+# CuArrays.@cufunc logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ(ŷ)
+
+
+"""
+    kldivergence(ŷ, y; dims=1, agg=mean, ϵ=eps(eltype(ŷ)))
+
+Return the
+[Kullback-Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence)
+between the given arrays interpreted as probability distributions.
+
+KL divergence is a measure of how much one probability distribution is different
+from the other.
+It is always non-negative and zero only when both the distributions are equal
+everywhere.
+"""
+function kldivergence(ŷ, y; dims=1, agg=mean, ϵ=eps(eltype(ŷ)))
+    entropy = agg(sum(y .* log.(y .+ ϵ), dims=dims))
+    cross_entropy = crossentropy(ŷ, y; dims=dims, agg=agg, ϵ=ϵ)
+    return entropy + cross_entropy
+end
+
+"""
+    poisson_loss(ŷ, y; agg=mean)
+
+Return how much the predicted distribution `ŷ` diverges from the expected Poisson
+distribution `y`; calculated as
+
+    agg(ŷ .- y .* log.(ŷ))
+
+[More information.](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson).
+"""
+poisson_loss(ŷ, y; agg=mean) = agg(ŷ .- y .* log.(ŷ))
+
+@deprecate poisson poisson_loss
+
+"""
+    hinge(ŷ, y; agg=mean)
+
+Return the [hinge loss](https://en.wikipedia.org/wiki/Hinge_loss) given the
+prediction `ŷ` and true labels `y` (containing 1 or -1); calculated as
+`agg(max.(0, 1 .- ŷ .* y))`.
+
+See also: [`squared_hinge`](@ref)
+"""
+hinge(ŷ, y; agg=mean) = agg(max.(0, 1 .- ŷ .* y))
+
+"""
+    squared_hinge(ŷ, y; agg=mean)
+
+Return the squared hinge loss given the prediction `ŷ` and true labels `y`
+(containing 1 or -1); calculated as `agg((max.(0, 1 .- ŷ .* y)).^2)`.
+
+See also: [`hinge`](@ref)
+"""
+squared_hinge(ŷ, y; agg=mean) = agg((max.(0, 1 .- ŷ .* y)).^2)
+
+"""
+    dice_coeff_loss(ŷ, y; smooth=1)
+
+Return a loss based on the Dice coefficient.
+Used in the [V-Net](https://arxiv.org/pdf/1606.04797v1.pdf) image segmentation
+architecture.
+Similar to the F1 score. Calculated as:
+
+    1 - (2*sum(y .* ŷ) + smooth) / (sum(y.^2) + sum(ŷ.^2) + smooth)
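+
+For example (an illustrative call; identical binary masks give zero loss, since the
+fraction above evaluates to one):
+
+    dice_coeff_loss([0.0, 1.0, 1.0], [0.0, 1.0, 1.0])  # ≈ 0.0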
+"""
+dice_coeff_loss(ŷ, y; smooth=ofeltype(ŷ, 1.0)) = 1 - (2*sum(y .* ŷ) + smooth) / (sum(y.^2) + sum(ŷ.^2) + smooth) #TODO
+
+"""
+    tversky_loss(ŷ, y; β=0.7)
+
+Return the [Tversky loss](https://arxiv.org/pdf/1706.05721.pdf).
+Used with imbalanced data to give more weight to false negatives.
+A larger β weighs recall more heavily than precision (by placing more emphasis on false negatives).
+Calculated as:
+
+    1 - (sum(y .* ŷ) + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)
+"""
+tversky_loss(ŷ, y; β=ofeltype(ŷ, 0.7)) = 1 - (sum(y .* ŷ) + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1) #TODO
\ No newline at end of file
diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl
index d3b37980..40c9d689 100644
--- a/src/layers/stateless.jl
+++ b/src/layers/stateless.jl
@@ -1,167 +1,3 @@
-# Cost functions
-"""
-    mae(ŷ, y; agg=mean)
-
-Return the Mean Absolute Error.
-
-    l = abs.(ŷ .- y)
-
-The results
-"""
-mae(ŷ, y; agg=mean) = agg(abs.(ŷ .- y))
-
-"""
-    mse(ŷ, y)
-
-Return the mean squared error between ŷ and y; calculated as
-`sum((ŷ .- y).^2) / length(y)`.
-
-# Examples
-```jldoctest
-julia> Flux.mse([0, 2], [1, 1])
-1//1
-```
-"""
-mse(ŷ, y; agg=mean) = agg((ŷ .- y).^2)
-
-"""
-    msle(ŷ, y; ϵ=eps(eltype(ŷ)))
-
-Return the mean of the squared logarithmic errors; calculated as
-`sum((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2) / length(y)`.
-The `ϵ` term provides numerical stability.
-
-Penalizes an under-predicted estimate greater than an over-predicted estimate.
-"""
-msle(ŷ, y; agg=mean, ϵ=eps(eltype(ŷ))) = agg((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2)
-
-
-
-"""
-    huber_loss(ŷ, y; δ=1)
-
-Return the mean of the [Huber loss](https://en.wikipedia.org/wiki/Huber_loss)
-given the prediction `ŷ` and true values `y`.
-
-                 | 0.5 * |ŷ - y|,            for |ŷ - y| <= δ
-    Huber loss = |
-                 |  δ * (|ŷ - y| - 0.5 * δ), otherwise
-"""
-function huber_loss(ŷ, y; agg=mean, δ=ofeltype(ŷ, 1))
-   abs_error = abs.(ŷ .- y)
-   temp = abs_error .< δ
-   x = ofeltype(ŷ, 0.5)
-   agg(((abs_error.^2) .* temp) .* x .+ δ*(abs_error .- x*δ) .* (1 .- temp))
-end
-
-# function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::Nothing)
-#   return -sum(y .* log.(ŷ)) * 1 // size(y, 2)
-# end
-
-# function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::Number)
-#   return -sum(y .* log.(ŷ)) .* weight * 1 // size(y, 2)
-# end
-
-# function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::AbstractVector)
-#   return -sum(y .* log.(ŷ) .* weight) * 1 // size(y, 2)
-# end
-
-"""
-    crossentropy(ŷ, y; weight = nothing)
-
-Return the cross entropy between the given probability distributions;
-calculated as `-sum(y .* log.(ŷ) .* weight) / size(y, 2)`.
-
-`weight` can be `Nothing`, a `Number` or an `AbstractVector`.
-`weight=nothing` acts like `weight=1` but is faster.
-
-See also: [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
-
-# Examples
-```jldoctest
-julia> Flux.crossentropy(softmax([-1.1491, 0.8619, 0.3127]), [1, 1, 0])
-3.085467254747739
-```
-"""
-# crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight=nothing) = _crossentropy(ŷ, y, weight)
-function crossentropy(ŷ, y; dims=1, agg=mean, ϵ=eps(eltype(ŷ)))
-    agg(.-sum(y .* log.(ŷ .+ ϵ); dims=dims))
-end
-
-"""
-    logitcrossentropy(ŷ, y; weight = 1)
-
-Return the crossentropy computed after a [`Flux.logsoftmax`](@ref) operation;
-calculated as `-sum(y .* logsoftmax(ŷ) .* weight) / size(y, 2)`.
-
-`logitcrossentropy(ŷ, y)` is mathematically equivalent to
-[`Flux.crossentropy(softmax(log(ŷ)), y)`](@ref) but it is more numerically stable.
- -See also: [`Flux.crossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref) - -# Examples -```jldoctest -julia> Flux.logitcrossentropy([-1.1491, 0.8619, 0.3127], [1, 1, 0]) -3.085467254747738 -``` -""" -# function logitcrossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1) -# return -sum(y .* logsoftmax(ŷ) .* weight) * 1 // size(y, 2) -# end -function logitcrossentropy(ŷ, y; dims=1, agg=mean) - agg(.-sum(y .* logsoftmax(ŷ; dims=dims); dims=dims)) -end - -""" - binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) - -Return ``-y*\\log(ŷ + ϵ) - (1-y)*\\log(1-ŷ + ϵ)``. The `ϵ` term provides numerical stability. - -Typically, the prediction `ŷ` is given by the output of a [`sigmoid`](@ref) activation. - -See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref) - -# Examples -```jldoctest -julia> Flux.binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0]) -3-element Array{Float64,1}: - 1.424397097347566 - 0.35231664672364077 - 0.8616703662235441 -``` -""" -# binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ) -function binarycrossentropy(ŷ, y; agg=mean, ϵ=eps(eltype(ŷ))) - agg(@.(-y*log(ŷ+ϵ) - (1-y)*log(1-ŷ+ϵ))) -end -# Re-definition to fix interaction with CuArrays. -# CuArrays.@cufunc binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ) - -""" - logitbinarycrossentropy(ŷ, y) - -`logitbinarycrossentropy(ŷ, y)` is mathematically equivalent to -[`Flux.binarycrossentropy(σ(log(ŷ)), y)`](@ref) but it is more numerically stable. - -See also: [`Flux.crossentropy`](@ref), [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref) - -# Examples -```jldoctest -julia> Flux.logitbinarycrossentropy.([-1.1491, 0.8619, 0.3127], [1, 1, 0]) -3-element Array{Float64,1}: - 1.4243970973475661 - 0.35231664672364094 - 0.8616703662235443 -``` -""" -# logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ(ŷ) - -function logitcrossentropy(ŷ, y; agg=mean) - agg(@.((1-y)*ŷ - logsigmoid(ŷ))) -end -# Re-definition to fix interaction with CuArrays. -# CuArrays.@cufunc logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ(ŷ) - # TODO normalise over last dimension is typically what you want to do. # Possible deprecation path: `normalise(x; dims=1)` -> `normalise(x; dims)` -> `normalise(x; dims=size(x)[end])` """ @@ -197,77 +33,6 @@ function normalise(x::AbstractArray; dims=1, ϵ=ofeltype(x, 1e-6)) return (x .- μ′) ./ (σ′.+ ϵ) end -""" - kldivergence(ŷ, y) - -Return the -[Kullback-Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) -between the given probability distributions. - -KL divergence is a measure of how much one probability distribution is different -from the other. -It is always non-negative and zero only when both the distributions are equal -everywhere. -""" -function kldivergence(ŷ, y; dims=1, agg=mean, ϵ=eps(eltype(ŷ))) - entropy = agg(sum(y .* log.(y .+ ϵ), dims=dims)) - cross_entropy = crossentropy(ŷ, y; dims=dims, agg=agg, ϵ=ϵ) - return entropy + cross_entropy -end - -""" - poisson(ŷ, y) - -# Return how much the predicted distribution `ŷ` diverges from the expected Poisson -# distribution `y`; calculated as `sum(ŷ .- y .* log.(ŷ)) / size(y, 2)`. -REDO -[More information.](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson). 
-""" -poisson(ŷ, y; agg=mean) = agg(ŷ .- y .* log.(ŷ)) - -""" - hinge(ŷ, y) - -Return the [hinge loss](https://en.wikipedia.org/wiki/Hinge_loss) given the -prediction `ŷ` and true labels `y` (containing 1 or -1); calculated as -`sum(max.(0, 1 .- ŷ .* y)) / size(y, 2)`. - -See also: [`squared_hinge`](@ref) -""" -hinge(ŷ, y; agg=mean) = agg(max.(0, 1 .- ŷ .* y)) - -""" - squared_hinge(ŷ, y) - -Return the squared hinge loss given the prediction `ŷ` and true labels `y` -(containing 1 or -1); calculated as `sum((max.(0, 1 .- ŷ .* y)).^2) / size(y, 2)`. - -See also: [`hinge`](@ref) -""" -squared_hinge(ŷ, y; agg=mean) = agg((max.(0, 1 .- ŷ .* y)).^2) - -""" - dice_coeff_loss(ŷ, y; smooth=1) - -Return a loss based on the dice coefficient. -Used in the [V-Net](https://arxiv.org/pdf/1606.04797v1.pdf) image segmentation -architecture. -Similar to the F1_score. Calculated as: - 1 - 2*sum(|ŷ .* y| + smooth) / (sum(ŷ.^2) + sum(y.^2) + smooth)` -""" -dice_coeff_loss(ŷ, y; smooth=ofeltype(ŷ, 1.0)) = 1 - (2*sum(y .* ŷ) + smooth) / (sum(y.^2) + sum(ŷ.^2) + smooth) #TODO - -""" - tversky_loss(ŷ, y; β=0.7) - -Return the [Tversky loss](https://arxiv.org/pdf/1706.05721.pdf). -Used with imbalanced data to give more weight to false negatives. -Larger β weigh recall higher than precision (by placing more emphasis on false negatives) -Calculated as: - 1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1) -""" -tversky_loss(ŷ, y; β=ofeltype(ŷ, 0.7)) = 1 - (sum(y .* ŷ) + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1) #TODO - """ flatten(x::AbstractArray) diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl index 128e5c7d..cc87169f 100644 --- a/test/cuda/cuda.jl +++ b/test/cuda/cuda.jl @@ -33,8 +33,8 @@ cx = gpu(x) x = [-1.1491, 0.8619, 0.3127] y = [1, 1, 0.] -@test Flux.binarycrossentropy.(σ.(x),y) ≈ Array(Flux.binarycrossentropy.(cu(σ.(x)),cu(y))) -@test Flux.logitbinarycrossentropy.(x,y) ≈ Array(Flux.logitbinarycrossentropy.(cu(x),cu(y))) +@test Flux.binarycrossentropy(σ.(x), y) ≈ Flux.binarycrossentropy(cu(σ.(x)), cu(y)) +@test Flux.logitbinarycrossentropy(x, y) ≈ Flux.logitbinarycrossentropy(cu(x), cu(y)) xs = rand(5, 5) ys = Flux.onehotbatch(1:5,1:5) diff --git a/test/layers/stateless.jl b/test/layers/stateless.jl index ebcd815c..18d4d640 100644 --- a/test/layers/stateless.jl +++ b/test/layers/stateless.jl @@ -56,12 +56,12 @@ const ϵ = 1e-7 logŷ, y = randn(3), rand(3) @testset "binarycrossentropy" begin - @test binarycrossentropy.(σ.(logŷ), y; ϵ=0) ≈ -y.*log.(σ.(logŷ)) - (1 .- y).*log.(1 .- σ.(logŷ)) - @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 .- y).*log.(1 .- σ.(logŷ) .+ eps.(σ.(logŷ))) + @test binarycrossentropy(σ.(logŷ), y; ϵ=0) ≈ mean(-y.*log.(σ.(logŷ)) - (1 .- y).*log.(1 .- σ.(logŷ))) + @test binarycrossentropy(σ.(logŷ), y) ≈ mean(-y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 .- y).*log.(1 .- σ.(logŷ) .+ eps.(σ.(logŷ)))) end @testset "logitbinarycrossentropy" begin - @test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y; ϵ=0) + @test logitbinarycrossentropy(logŷ, y) ≈ binarycrossentropy(σ.(logŷ), y; ϵ=0) end y = [1 2 3] @@ -86,8 +86,8 @@ const ϵ = 1e-7 y = [0.1 0.2 0.3] ŷ = [0.4 0.5 0.6] @testset "poisson" begin - @test Flux.poisson(ŷ, y) ≈ 0.6278353988097339 - @test Flux.poisson(y, y) ≈ 0.5044459776946685 + @test Flux.poisson_loss(ŷ, y) ≈ 0.6278353988097339 + @test Flux.poisson_loss(y, y) ≈ 0.5044459776946685 end y = [1.0 0.5 0.3 2.4]