Removed spurious promotions

This commit is contained in:
Adarsh Kumar 2020-02-06 01:06:41 +05:30 committed by GitHub
parent b5184553d4
commit 7710bb0b4b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 8 additions and 6 deletions

View File

@ -6,7 +6,7 @@ using NNlib: logsoftmax, logσ
mae(, y) mae(, y)
L1 loss function. Computes the mean of absolute error between prediction and true values L1 loss function. Computes the mean of absolute error between prediction and true values
""" """
mae(, y) = sum(abs.(.- y)) * 1 // length(y) mae(, y) = sum(abs.( .- y)) * 1 // length(y)
""" """
@ -42,9 +42,9 @@ Alias:
msle(,y;ϵ1=eps.(Float64.()),ϵ2=eps.(Float64.(y))) msle(,y;ϵ1=eps.(Float64.()),ϵ2=eps.(Float64.(y)))
""" """
mean_squared_logarithmic_error(, y;ϵ1=eps.(Float64.()),ϵ2=eps.(Float64.(y))) = sum((log.(+ϵ1).-log.(y+ϵ2)).^2) * 1 // length(y) mean_squared_logarithmic_error(, y;ϵ1=eps.(),ϵ2=eps.(eltype().(y))) = sum((log.(+ϵ1).-log.(y+ϵ2)).^2) * 1 // length(y)
#Alias #Alias
msle(, y;ϵ1=eps.(Float64.()),ϵ2=eps.(Float64.(y))) = sum((log.(+ϵ1).-log.(y+ϵ2)).^2) * 1 // length(y) msle(, y;ϵ1=eps.(),ϵ2=eps.(eltype().(y))) = sum((log.(+ϵ1).-log.(y+ϵ2)).^2) * 1 // length(y)
@ -74,12 +74,14 @@ Computes the mean of the Huber loss between prediction ŷ and true values y. By
""" """
function huber_loss(, y,delta=1.0) function huber_loss(, y,delta=1.0)
abs_error = abs.(.-y) abs_error = abs.(.-y)
hub_loss =0 type_ = eltype()
delta = type_(delta)
hub_loss =type_(0)
for i in 1:length(y) for i in 1:length(y)
if (abs_error[i]<=delta) if (abs_error[i]<=delta)
hub_loss+=abs_error[i]^2*0.5 hub_loss+=abs_error[i]^2*type_(0.5)
else else
hub_loss+=delta*(abs_error[i]-0.5*delta) hub_loss+=delta*(abs_error[i]-type_(0.5*delta))
end end
return hub_loss*1//length(y) return hub_loss*1//length(y)