Merge pull request #606 from pshashk/patch-3
Add `corrected` argument to std
This commit is contained in:
commit
f17a5acd2b
@ -40,12 +40,17 @@ but it is more numerically stable.
|
||||
logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
|
||||
|
||||
"""
|
||||
normalise(x::AbstractArray, dims::Int=1)
|
||||
normalise(x::AbstractArray; dims=1)
|
||||
|
||||
Normalises x to mean 0 and standard deviation 1, across the dimensions given by dims. Defaults to normalising over columns.
|
||||
"""
|
||||
function normalise(x::AbstractArray, dims::Int=1)
|
||||
function normalise(x::AbstractArray; dims=1)
|
||||
μ′ = mean(x, dims = dims)
|
||||
σ′ = std(x, dims = dims, mean = μ′, corrected=false)
|
||||
return (x .- μ′) ./ σ′
|
||||
end
|
||||
|
||||
function normalise(x::AbstractArray, dims=1)
|
||||
Base.depwarn("`normalise(x::AbstractArray, dims)` is deprecated, use `normalise(a, dims=dims)` instead.", :normalise)
|
||||
normalise(x, dims = dims)
|
||||
end
|
||||
|
@ -321,9 +321,9 @@ dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
|
||||
@grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs)
|
||||
|
||||
# Hacks to get std working
|
||||
Statistics.std(x::TrackedArray; dims = :, mean = Statistics.mean(x, dims = dims)) = _std(x,mean,dims)
|
||||
_std(x::TrackedArray, mean, dims) = sqrt.(sum((x .- mean).^2, dims = dims) ./ (mapreduce(i -> size(x,i),*, dims) - 1))
|
||||
_std(x::TrackedArray, mean, ::Colon) = sqrt.(sum((x .- mean).^2) ./ (length(x) - 1))
|
||||
Statistics.std(x::TrackedArray; dims = :, mean = Statistics.mean(x, dims = dims), corrected::Bool = true) = _std(x,mean,dims,corrected)
|
||||
_std(x::TrackedArray, mean, dims, corrected) = sqrt.(sum((x .- mean).^2, dims = dims) ./ (mapreduce(i -> size(x,i),*, dims) - corrected))
|
||||
_std(x::TrackedArray, mean, ::Colon, corrected) = sqrt.(sum((x .- mean).^2) ./ (length(x) - corrected))
|
||||
|
||||
LinearAlgebra.norm(x::TrackedArray, p::Real = 2) =
|
||||
sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0
|
||||
|
@ -178,6 +178,10 @@ end
|
||||
|
||||
@test gradtest(x -> std(x), rand(5,5))
|
||||
@test gradtest(x -> std(x, dims = 1), rand(5,5))
|
||||
@test gradtest(x -> std(x, dims = 1, corrected = false), rand(5,5))
|
||||
|
||||
@test gradtest(x -> Flux.normalise(x), rand(4,3))
|
||||
@test gradtest(x -> Flux.normalise(x, dims = 2), rand(3,4))
|
||||
|
||||
@test gradtest((x, y) -> x .* y, rand(5), rand(5))
|
||||
@test gradtest(dot, rand(5), rand(5))
|
||||
|
Loading…
Reference in New Issue
Block a user