Preserve element type in stateless layers
This should fix issues where Float32 weights and activations get suddenly switched to Float64 on the backwards pass
This commit is contained in:
parent
f0d5624ed2
commit
c59de875d4
|
@ -2,16 +2,21 @@ using NNlib: logsoftmax, logσ
|
|||
|
||||
# Cost functions
|
||||
|
||||
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
|
||||
function mse(ŷ, y)
|
||||
x = sum((ŷ .- y).^2)
|
||||
return x/Tracker.tracked_eltype(x)(length(y))
|
||||
end
|
||||
|
||||
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
|
||||
-sum(y .* log.(ŷ) .* weight) / size(y, 2)
|
||||
x = -sum(y .* log.(ŷ) .* weight)
|
||||
return x/Tracker.tracked_eltype(x)(size(y, 2))
|
||||
end
|
||||
|
||||
@deprecate logloss(x, y) crossentropy(x, y)
|
||||
|
||||
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
|
||||
return -sum(y .* logsoftmax(logŷ) .* weight) / size(y, 2)
|
||||
x = -sum(y .* logsoftmax(logŷ) .* weight)
|
||||
return x/Tracker.tracked_eltype(x)(size(y, 2))
|
||||
end
|
||||
|
||||
"""
|
||||
|
|
|
@ -494,3 +494,8 @@ if VERSION < v"1.1.0-DEV.548"
|
|||
end
|
||||
end
|
||||
end
|
||||
|
||||
# Get the element type of an array or the element type of the inner array of a tracked array
|
||||
tracked_eltype(x) = eltype(x)
|
||||
tracked_eltype(x::TrackedArray) = eltype(data(x))
|
||||
tracked_eltype(x::TrackedReal{T}) where T = T
|
||||
|
|
Loading…
Reference in New Issue