Update stateless.jl
This commit is contained in:
parent
1f2643c95c
commit
ac94754281
|
@ -46,6 +46,7 @@ given the prediction `ŷ` and true values `y`.
|
||||||
Huber loss = |
|
Huber loss = |
|
||||||
| δ * (|ŷ - y| - 0.5 * δ), otherwise
|
| δ * (|ŷ - y| - 0.5 * δ), otherwise
|
||||||
"""
|
"""
|
||||||
|
#TODO: remove dropgrad when Zygote can handle this function with CuArrays
|
||||||
function huber_loss(ŷ, y; δ=eltype(ŷ)(1))
|
function huber_loss(ŷ, y; δ=eltype(ŷ)(1))
|
||||||
abs_error = abs.(ŷ .- y)
|
abs_error = abs.(ŷ .- y)
|
||||||
temp = Zygote.dropgrad(abs_error .< δ)
|
temp = Zygote.dropgrad(abs_error .< δ)
|
||||||
|
|
Loading…
Reference in New Issue