From 2cb94981a0176f070eb2dec31c00ef125613ce3f Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 27 Oct 2017 12:05:37 +0100 Subject: [PATCH] gpu-ready log --- src/Flux.jl | 1 + src/layers/stateless.jl | 4 +-- src/numeric.jl | 80 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 83 insertions(+), 2 deletions(-) create mode 100644 src/numeric.jl diff --git a/src/Flux.jl b/src/Flux.jl index acefff19..ce3861e5 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -21,6 +21,7 @@ include("optimise/Optimise.jl") using .Optimise include("utils.jl") +include("numeric.jl") include("onehot.jl") include("treelike.jl") diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 3931c216..56d18349 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -3,12 +3,12 @@ mse(ŷ, y) = sum((ŷ .- y).^2)/length(y) crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat) = - -sum(y .* log.(ŷ)) / size(y, 2) + -sum(y .* log_fast.(ŷ)) / size(y, 2) @deprecate logloss(x, y) crossentropy(x, y) function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat) logŷ = logŷ .- maximum(logŷ, 1) - ypred = logŷ .- log.(sum(exp.(logŷ), 1)) + ypred = logŷ .- log_fast.(sum(exp.(logŷ), 1)) -sum(y .* ypred) / size(y, 2) end diff --git a/src/numeric.jl b/src/numeric.jl new file mode 100644 index 00000000..9c444043 --- /dev/null +++ b/src/numeric.jl @@ -0,0 +1,80 @@ +using Base.Math: @horner, significand_bits, exponent_raw_max, exponent_bias + +if VERSION < v"0.7.0-DEV.1430" + using Base.Math.fpinttype +else + using Base.uinttype +end + +# log_fast from +# https://github.com/musm/SLEEF.jl/blob/c9dcd2eb090d69ec40790f19798c5fef2aba2616/src/log.jl + +const MLN2 = 6.931471805599453094172321214581765680755001343602552541206800094933936219696955e-01 # log(2) + +@inline float2integer(d::Float64) = (reinterpret(Int64, d) >> significand_bits(Float64)) % Int +@inline float2integer(d::Float32) = (reinterpret(Int32, d) >> significand_bits(Float32)) % Int + +@inline function ilogb2k(d::T) 
where {T<:Union{Float32,Float64}} + (float2integer(d) & exponent_raw_max(T)) - exponent_bias(T) +end + +@inline function ldexp3k(x::T, e::Int) where {T<:Union{Float32,Float64}} + if VERSION < v"0.7.0-DEV.1430" + reinterpret(T, reinterpret(Unsigned, x) + (Int64(e) << significand_bits(T)) % fpinttype(T)) + else + reinterpret(T, reinterpret(Unsigned, x) + (Int64(e) << significand_bits(T)) % uinttype(T)) + end +end + +""" + log_fast(x) +Compute the natural logarithm of `x`. The inverse of the natural logarithm is +the natural exponential function `exp(x)` +""" +function log_fast end + +let +global log_fast + +c8d = 0.153487338491425068243146 +c7d = 0.152519917006351951593857 +c6d = 0.181863266251982985677316 +c5d = 0.222221366518767365905163 +c4d = 0.285714294746548025383248 +c3d = 0.399999999950799600689777 +c2d = 0.6666666666667778740063 +c1d = 2.0 + +c5f = 0.2392828464508056640625f0 +c4f = 0.28518211841583251953125f0 +c3f = 0.400005877017974853515625f0 +c2f = 0.666666686534881591796875f0 +c1f = 2f0 + +global @inline log_fast_kernel(x::Float64) = @horner x c1d c2d c3d c4d c5d c6d c7d c8d +global @inline log_fast_kernel(x::Float32) = @horner x c1f c2f c3f c4f c5f + +function log_fast(d::T) where {T<:Union{Float32,Float64}} + o = d < realmin(T) + o && (d *= T(Int64(1) << 32) * T(Int64(1) << 32)) + + e = ilogb2k(d * T(1.0/0.75)) + m = ldexp3k(d, -e) + o && (e -= 64) + + x = (m - 1) / (m + 1) + x2 = x * x + + t = log_fast_kernel(x2) + + x = x * t + T(MLN2) * e + + isinf(d) && (x = T(Inf)) + (d < 0 || isnan(d)) && (x = T(NaN)) + d == 0 && (x = -T(Inf)) + + return x +end +end + +log_fast(x::Union{Int32,Int64}) = log_fast(float(x))