Add corrected
argument to std
Fixes ffe037c485/src/layers/stateless.jl (L49)
This commit is contained in:
parent
ffe037c485
commit
368c29e5e3
@ -321,9 +321,9 @@ dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
|
|||||||
@grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs)
|
@grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs)
|
||||||
|
|
||||||
# Hacks to get std working
|
# Hacks to get std working
|
||||||
Statistics.std(x::TrackedArray; dims = :, mean = Statistics.mean(x, dims = dims)) = _std(x,mean,dims)
|
Statistics.std(x::TrackedArray; dims = :, mean = Statistics.mean(x, dims = dims), corrected::Bool = true) = _std(x,mean,dims,corrected)
|
||||||
_std(x::TrackedArray, mean, dims) = sqrt.(sum((x .- mean).^2, dims = dims) ./ (mapreduce(i -> size(x,i),*, dims) - 1))
|
_std(x::TrackedArray, mean, dims, corrected) = sqrt.(sum((x .- mean).^2, dims = dims) ./ (mapreduce(i -> size(x,i),*, dims) - corrected))
|
||||||
_std(x::TrackedArray, mean, ::Colon) = sqrt.(sum((x .- mean).^2) ./ (length(x) - 1))
|
_std(x::TrackedArray, mean, ::Colon, corrected) = sqrt.(sum((x .- mean).^2) ./ (length(x) - corrected))
|
||||||
|
|
||||||
LinearAlgebra.norm(x::TrackedArray, p::Real = 2) =
|
LinearAlgebra.norm(x::TrackedArray, p::Real = 2) =
|
||||||
sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0
|
sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0
|
||||||
|
Loading…
Reference in New Issue
Block a user