Merge branch 'patch-6' of https://github.com/thebhatman/Flux.jl into patch-6
This commit is contained in:
commit
6e289ef939
@ -32,6 +32,18 @@ Flux.train!(loss, ps, data, opt)
|
|||||||
|
|
||||||
The objective will almost always be defined in terms of some *cost function* that measures the distance of the prediction `m(x)` from the target `y`. Flux has several of these built in, like `mse` for mean squared error or `crossentropy` for cross entropy loss, but you can calculate it however you want.
|
The objective will almost always be defined in terms of some *cost function* that measures the distance of the prediction `m(x)` from the target `y`. Flux has several of these built in, like `mse` for mean squared error or `crossentropy` for cross entropy loss, but you can calculate it however you want.
|
||||||
|
|
||||||
|
In-built loss functions:
|
||||||
|
```@docs
|
||||||
|
mse
|
||||||
|
crossentropy
|
||||||
|
logitcrossentropy
|
||||||
|
binarycrossentropy
|
||||||
|
logitbinarycrossentropy
|
||||||
|
kldivergence
|
||||||
|
poisson
|
||||||
|
hinge
|
||||||
|
```
|
||||||
|
|
||||||
## Datasets
|
## Datasets
|
||||||
|
|
||||||
The `data` argument provides a collection of data to train with (usually a set of inputs `x` and target outputs `y`). For example, here's a dummy data set with only one data point:
|
The `data` argument provides a collection of data to train with (usually a set of inputs `x` and target outputs `y`). For example, here's a dummy data set with only one data point:
|
||||||
|
@ -54,3 +54,25 @@ function normalise(x::AbstractArray, dims)
|
|||||||
Base.depwarn("`normalise(x::AbstractArray, dims)` is deprecated, use `normalise(a, dims=dims)` instead.", :normalise)
|
Base.depwarn("`normalise(x::AbstractArray, dims)` is deprecated, use `normalise(a, dims=dims)` instead.", :normalise)
|
||||||
normalise(x, dims = dims)
|
normalise(x, dims = dims)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
Kullback Leibler Divergence(KL Divergence)
|
||||||
|
KLDivergence is a measure of how much one probability distribution is different from the other.
|
||||||
|
It is always non-negative and zero only when both the distributions are equal everywhere.
|
||||||
|
https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
|
||||||
|
"""
|
||||||
|
function kldivergence(ŷ, y)
|
||||||
|
entropy = sum(y .* log.(y)) *1 //size(y,2)
|
||||||
|
cross_entropy = crossentropy(ŷ, y)
|
||||||
|
return entropy + cross_entropy
|
||||||
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
Poisson Loss function
|
||||||
|
Poisson loss function is a measure of how the predicted distribution diverges from the expected distribution.
|
||||||
|
https://isaacchanghau.github.io/post/loss_functions/
|
||||||
|
"""
|
||||||
|
poisson(ŷ, y) = sum(ŷ .- y .* log.(ŷ)) *1 // size(y,2)
|
||||||
|
|
||||||
|
hinge(ŷ, y) = sum(max.(0, 1 .- ŷ .* y)) *1 // size(y,2)
|
||||||
|
|
||||||
|
@ -49,7 +49,28 @@ const ϵ = 1e-7
|
|||||||
@testset "logitbinarycrossentropy" begin
|
@testset "logitbinarycrossentropy" begin
|
||||||
@test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y; ϵ=0)
|
@test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y; ϵ=0)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
y = [1 2 3]
|
||||||
|
y1 = [4.0 5.0 6.0]
|
||||||
|
@testset "kldivergence" begin
|
||||||
|
@test Flux.kldivergence(y, y1) ≈ 4.761838062403337
|
||||||
|
@test Flux.kldivergence(y, y) ≈ 0
|
||||||
|
end
|
||||||
|
|
||||||
|
y = [1 2 3 4]
|
||||||
|
y1 = [5.0 6.0 7.0 8.0]
|
||||||
|
@testset "hinge" begin
|
||||||
|
@test Flux.hinge(y, y1) ≈ 0
|
||||||
|
@test Flux.hinge(y, 0.5 .* y) ≈ 0.125
|
||||||
|
end
|
||||||
|
|
||||||
|
y = [0.1 0.2 0.3]
|
||||||
|
y1 = [0.4 0.5 0.6]
|
||||||
|
@testset "poisson" begin
|
||||||
|
@test Flux.poisson(y, y1) ≈ 1.0160455586700767
|
||||||
|
@test Flux.poisson(y, y) ≈ 0.5044459776946685
|
||||||
|
end
|
||||||
|
|
||||||
@testset "no spurious promotions" begin
|
@testset "no spurious promotions" begin
|
||||||
for T in (Float16, Float32, Float64)
|
for T in (Float16, Float32, Float64)
|
||||||
y = rand(T, 2)
|
y = rand(T, 2)
|
||||||
|
Loading…
Reference in New Issue
Block a user