use explicit converts

This commit is contained in:
Dhairya Gandhi 2019-01-03 19:31:56 +05:30
parent d54b0e312a
commit 65a41f2de6
2 changed files with 6 additions and 6 deletions

View File

@ -2,16 +2,16 @@ using NNlib: logsoftmax, logσ
# Cost functions # Cost functions
mse(, y) = sum(( .- y).^2)/length(y) mse(, y; efftype = eltype()) = sum(( .- y).^2)/convert(efftype, length(y))
function crossentropy(::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1) function crossentropy(::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1, efftype = eltype())
-sum(y .* log.() .* weight) / size(y, 2) -sum(y .* log.() .* weight) / convert(efftype, size(y, 2))
end end
@deprecate logloss(x, y) crossentropy(x, y) @deprecate logloss(x, y) crossentropy(x, y)
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1) function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1, efftype = eltype())
return -sum(y .* logsoftmax(logŷ) .* weight) / size(y, 2) return -sum(y .* logsoftmax(logŷ) .* weight) / convert(efftype, size(y, 2))
end end
""" """

View File

@ -72,7 +72,7 @@ for (M, f, arity) in DiffRules.diffrules()
f = :($M.$f) f = :($M.$f)
@eval begin @eval begin
@grad $f(a::TrackedReal, b::TrackedReal) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db) @grad $f(a::TrackedReal, b::TrackedReal) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db)
@grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ -> (Δ * convert(TrackedReal{eltype(Δ)}, $da), _zero(b)) @grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ -> (Δ * $da, _zero(b))
@grad $f(a::Real, b::TrackedReal) = $f(a, data(b)), Δ -> (_zero(a), Δ * $db) @grad $f(a::Real, b::TrackedReal) = $f(a, data(b)), Δ -> (_zero(a), Δ * $db)
$f(a::TrackedReal, b::TrackedReal) = track($f, a, b) $f(a::TrackedReal, b::TrackedReal) = track($f, a, b)
$f(a::TrackedReal, b::Real) = track($f, a, b) $f(a::TrackedReal, b::Real) = track($f, a, b)