Merge #1129
1129: Added dropgrad in huber_loss r=CarloLucibello a=HenriDeh Workaround to prevent `iterate(::nothing)` when working with CuArrays. See issue #1128 Co-authored-by: HenriDeh <47037088+HenriDeh@users.noreply.github.com>
This commit is contained in:
commit
d9b07475b0
|
@ -46,9 +46,10 @@ given the prediction `ŷ` and true values `y`.
|
|||
Huber loss = |
|
||||
| δ * (|ŷ - y| - 0.5 * δ), otherwise
|
||||
"""
|
||||
#TODO: remove dropgrad when Zygote can handle this function with CuArrays
|
||||
function huber_loss(ŷ, y; δ=eltype(ŷ)(1))
|
||||
abs_error = abs.(ŷ .- y)
|
||||
temp = abs_error .< δ
|
||||
temp = Zygote.dropgrad(abs_error .< δ)
|
||||
x = eltype(ŷ)(0.5)
|
||||
hub_loss = sum(((abs_error.^2) .* temp) .* x .+ δ*(abs_error .- x*δ) .* (1 .- temp)) * 1 // length(y)
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue