From 65a41f2de697ed131790dde332fab607cb7581d2 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 3 Jan 2019 19:31:56 +0530 Subject: [PATCH] use explicit converts --- src/layers/stateless.jl | 10 +++++----- src/tracker/lib/real.jl | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 891ec230..a2d48b6e 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -2,16 +2,16 @@ using NNlib: logsoftmax, logσ # Cost functions -mse(ŷ, y) = sum((ŷ .- y).^2)/length(y) +mse(ŷ, y; efftype = eltype(ŷ)) = sum((ŷ .- y).^2)/convert(efftype, length(y)) -function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1) - -sum(y .* log.(ŷ) .* weight) / size(y, 2) +function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1, efftype = eltype(ŷ)) + -sum(y .* log.(ŷ) .* weight) / convert(efftype, size(y, 2)) end @deprecate logloss(x, y) crossentropy(x, y) -function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1) - return -sum(y .* logsoftmax(logŷ) .* weight) / size(y, 2) +function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1, efftype = eltype(ŷ)) + return -sum(y .* logsoftmax(logŷ) .* weight) / convert(efftype, size(y, 2)) end """ diff --git a/src/tracker/lib/real.jl b/src/tracker/lib/real.jl index bb2d8581..c5acf9fe 100644 --- a/src/tracker/lib/real.jl +++ b/src/tracker/lib/real.jl @@ -72,7 +72,7 @@ for (M, f, arity) in DiffRules.diffrules() f = :($M.$f) @eval begin @grad $f(a::TrackedReal, b::TrackedReal) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db) - @grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ -> (Δ * convert(TrackedReal{eltype(Δ)}, $da), _zero(b)) + @grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ -> (Δ * $da, _zero(b)) @grad $f(a::Real, b::TrackedReal) = $f(a, data(b)), Δ -> (_zero(a), Δ * $db) $f(a::TrackedReal, b::TrackedReal) = track($f, a, b) $f(a::TrackedReal, b::Real) = track($f, a, b)