Compare commits

...

1 Commits

Author SHA1 Message Date
Elliot Saba c59de875d4 Preserve element type in stateless layers
This should fix issues where Float32 weights and activations get
suddenly switched to Float64 on the backwards pass
2019-01-11 02:11:10 -05:00
2 changed files with 13 additions and 3 deletions

View File

@ -2,16 +2,21 @@ using NNlib: logsoftmax, logσ
# Cost functions
mse(, y) = sum(( .- y).^2)/length(y)
function mse(, y)
x = sum(( .- y).^2)
return x/Tracker.tracked_eltype(x)(length(y))
end
function crossentropy(::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
-sum(y .* log.() .* weight) / size(y, 2)
x = -sum(y .* log.() .* weight)
return x/Tracker.tracked_eltype(x)(size(y, 2))
end
@deprecate logloss(x, y) crossentropy(x, y)
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
return -sum(y .* logsoftmax(logŷ) .* weight) / size(y, 2)
x = -sum(y .* logsoftmax(logŷ) .* weight)
return x/Tracker.tracked_eltype(x)(size(y, 2))
end
"""

View File

@ -494,3 +494,8 @@ if VERSION < v"1.1.0-DEV.548"
end
end
end
# Get the element type of an array or the element type of the inner array of a tracked array
tracked_eltype(x) = eltype(x)
tracked_eltype(x::TrackedArray) = eltype(data(x))
tracked_eltype(x::TrackedReal{T}) where T = T