parent
d49d121a65
commit
1f2643c95c
|
@ -48,7 +48,7 @@ given the prediction `ŷ` and true values `y`.
|
||||||
"""
|
"""
|
||||||
function huber_loss(ŷ, y; δ=eltype(ŷ)(1))
|
function huber_loss(ŷ, y; δ=eltype(ŷ)(1))
|
||||||
abs_error = abs.(ŷ .- y)
|
abs_error = abs.(ŷ .- y)
|
||||||
temp = abs_error .< δ
|
temp = Zygote.dropgrad(abs_error .< δ)
|
||||||
x = eltype(ŷ)(0.5)
|
x = eltype(ŷ)(0.5)
|
||||||
hub_loss = sum(((abs_error.^2) .* temp) .* x .+ δ*(abs_error .- x*δ) .* (1 .- temp)) * 1 // length(y)
|
hub_loss = sum(((abs_error.^2) .* temp) .* x .+ δ*(abs_error .- x*δ) .* (1 .- temp)) * 1 // length(y)
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in New Issue