diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 3f97e1fd..b26073f3 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -48,7 +48,7 @@ given the prediction `ŷ` and true values `y`. """ function huber_loss(ŷ, y; δ=eltype(ŷ)(1)) abs_error = abs.(ŷ .- y) - temp = abs_error .< δ + temp = Zygote.dropgrad(abs_error .< δ) x = eltype(ŷ)(0.5) hub_loss = sum(((abs_error.^2) .* temp) .* x .+ δ*(abs_error .- x*δ) .* (1 .- temp)) * 1 // length(y) end