Add dropgrad in huber_loss

Workaround for issue #1128
This commit is contained in:
HenriDeh 2020-04-17 13:34:04 +02:00 committed by GitHub
parent d49d121a65
commit 1f2643c95c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -48,7 +48,7 @@ given the prediction `ŷ` and true values `y`.
""" """
function huber_loss(, y; δ=eltype()(1)) function huber_loss(, y; δ=eltype()(1))
abs_error = abs.( .- y) abs_error = abs.( .- y)
temp = abs_error .< δ temp = Zygote.dropgrad(abs_error .< δ)
x = eltype()(0.5) x = eltype()(0.5)
hub_loss = sum(((abs_error.^2) .* temp) .* x .+ δ*(abs_error .- x*δ) .* (1 .- temp)) * 1 // length(y) hub_loss = sum(((abs_error.^2) .* temp) .* x .+ δ*(abs_error .- x*δ) .* (1 .- temp)) * 1 // length(y)
end end