Preserve element type in stateless layers

This should fix issues where Float32 weights and activations get
suddenly switched to Float64 on the backwards pass
This commit is contained in:
Elliot Saba 2019-01-11 02:11:10 -05:00
parent f0d5624ed2
commit c59de875d4
2 changed files with 13 additions and 3 deletions

View File

@ -2,16 +2,21 @@ using NNlib: logsoftmax, logσ
# Cost functions
mse(, y) = sum(( .- y).^2)/length(y)
function mse(, y)
x = sum(( .- y).^2)
return x/Tracker.tracked_eltype(x)(length(y))
end
function crossentropy(::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
-sum(y .* log.() .* weight) / size(y, 2)
x = -sum(y .* log.() .* weight)
return x/Tracker.tracked_eltype(x)(size(y, 2))
end
@deprecate logloss(x, y) crossentropy(x, y)
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
return -sum(y .* logsoftmax(logŷ) .* weight) / size(y, 2)
x = -sum(y .* logsoftmax(logŷ) .* weight)
return x/Tracker.tracked_eltype(x)(size(y, 2))
end
"""

View File

@ -494,3 +494,8 @@ if VERSION < v"1.1.0-DEV.548"
end
end
end
# Get the element type of an array or the element type of the inner array of a tracked array
tracked_eltype(x) = eltype(x)
tracked_eltype(x::TrackedArray) = eltype(data(x))
tracked_eltype(x::TrackedReal{T}) where T = T