From 368c29e5e31fc7920f6c164572a524c1fd8ff33f Mon Sep 17 00:00:00 2001 From: pshashk Date: Fri, 8 Feb 2019 15:23:27 +0300 Subject: [PATCH 1/7] Add `corrected` argument to std Fixes https://github.com/FluxML/Flux.jl/blob/ffe037c485e50d53aeea0df97a81cd61fcc6ee81/src/layers/stateless.jl#L49 --- src/tracker/lib/array.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tracker/lib/array.jl b/src/tracker/lib/array.jl index 33504819..52b92cf7 100644 --- a/src/tracker/lib/array.jl +++ b/src/tracker/lib/array.jl @@ -321,9 +321,9 @@ dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys) @grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs) # Hacks to get std working -Statistics.std(x::TrackedArray; dims = :, mean = Statistics.mean(x, dims = dims)) = _std(x,mean,dims) -_std(x::TrackedArray, mean, dims) = sqrt.(sum((x .- mean).^2, dims = dims) ./ (mapreduce(i -> size(x,i),*, dims) - 1)) -_std(x::TrackedArray, mean, ::Colon) = sqrt.(sum((x .- mean).^2) ./ (length(x) - 1)) +Statistics.std(x::TrackedArray; dims = :, mean = Statistics.mean(x, dims = dims), corrected::Bool = true) = _std(x,mean,dims,corrected) +_std(x::TrackedArray, mean, dims, corrected) = sqrt.(sum((x .- mean).^2, dims = dims) ./ (mapreduce(i -> size(x,i),*, dims) - corrected)) +_std(x::TrackedArray, mean, ::Colon, corrected) = sqrt.(sum((x .- mean).^2) ./ (length(x) - corrected)) LinearAlgebra.norm(x::TrackedArray, p::Real = 2) = sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0 From 4f6432d1335d51329f3038b31d5ddfcf829f048e Mon Sep 17 00:00:00 2001 From: pshashk Date: Fri, 8 Feb 2019 15:28:07 +0300 Subject: [PATCH 2/7] test --- test/tracker.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/tracker.jl b/test/tracker.jl index fe3c0390..41f38ec4 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -178,6 +178,7 @@ end @test gradtest(x -> std(x), rand(5,5)) @test gradtest(x -> std(x, dims = 1), rand(5,5)) +@test gradtest(x -> std(x, dims = 1, corrected = false), rand(5,5)) @test gradtest((x, y) -> x .* y, rand(5), rand(5)) @test gradtest(dot, rand(5), rand(5)) From 37385e0dbdbe982a1557c8c49d59622201e39574 Mon Sep 17 00:00:00 2001 From: pshashk Date: Fri, 8 Feb 2019 15:43:50 +0300 Subject: [PATCH 3/7] test normalise --- test/tracker.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/tracker.jl b/test/tracker.jl index 41f38ec4..385d941a 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -180,6 +180,9 @@ end @test gradtest(x -> std(x, dims = 1), rand(5,5)) @test gradtest(x -> std(x, dims = 1, corrected = false), rand(5,5)) +@test gradtest(x -> Flux.normalise(x), rand(4,3)) +@test gradtest(x -> Flux.normalise(x, 1), rand(3,4)) + @test gradtest((x, y) -> x .* y, rand(5), rand(5)) @test gradtest(dot, rand(5), rand(5)) From 911c901294ffeb04c3a0c2ebe8f2396c24d032ef Mon Sep 17 00:00:00 2001 From: pshashk Date: Fri, 8 Feb 2019 16:00:32 +0300 Subject: [PATCH 4/7] `dims` kwarg --- src/layers/stateless.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 4b868261..262a65c0 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -40,12 +40,17 @@ but it is more numerically stable. logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ) """ - normalise(x::AbstractArray, dims::Int=1) + normalise(x::AbstractArray; dims::Int=1) Normalises x to mean 0 and standard deviation 1, across the dimensions given by dims. Defaults to normalising over columns. """ -function normalise(x::AbstractArray, dims::Int=1) +function normalise(x::AbstractArray; dims::Int=1) μ′ = mean(x, dims = dims) σ′ = std(x, dims = dims, mean = μ′, corrected=false) return (x .- μ′) ./ σ′ end + +function normalise(x::AbstractArray, dims::Int=1) + Base.depwarn("`normalise(x::AbstractArray, dims)` is deprecated, use `normalise(a, dims=dims)` instead.", :normalise) + normalise(x, dims = dims) +end From ae10421bfed8adb02f76ca7d52d65bf944d3a0e4 Mon Sep 17 00:00:00 2001 From: pshashk Date: Fri, 8 Feb 2019 16:02:03 +0300 Subject: [PATCH 5/7] fix normalise test for dims kwarg --- test/tracker.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/tracker.jl b/test/tracker.jl index 385d941a..47ce7166 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -181,7 +181,7 @@ end @test gradtest(x -> std(x, dims = 1, corrected = false), rand(5,5)) @test gradtest(x -> Flux.normalise(x), rand(4,3)) -@test gradtest(x -> Flux.normalise(x, 1), rand(3,4)) +@test gradtest(x -> Flux.normalise(x, dims = 2), rand(3,4)) @test gradtest((x, y) -> x .* y, rand(5), rand(5)) @test gradtest(dot, rand(5), rand(5)) From c3e04392d83fa457de891e63a45673bf32a5e38c Mon Sep 17 00:00:00 2001 From: pshashk Date: Fri, 8 Feb 2019 16:15:37 +0300 Subject: [PATCH 6/7] drop dims type restriction --- src/layers/stateless.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 262a65c0..66309327 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -44,13 +44,13 @@ logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ) Normalises x to mean 0 and standard deviation 1, across the dimensions given by dims. Defaults to normalising over columns. """ -function normalise(x::AbstractArray; dims::Int=1) +function normalise(x::AbstractArray; dims=1) μ′ = mean(x, dims = dims) σ′ = std(x, dims = dims, mean = μ′, corrected=false) return (x .- μ′) ./ σ′ end -function normalise(x::AbstractArray, dims::Int=1) +function normalise(x::AbstractArray, dims=1) Base.depwarn("`normalise(x::AbstractArray, dims)` is deprecated, use `normalise(a, dims=dims)` instead.", :normalise) normalise(x, dims = dims) end From b074b2491a901305f34ef5f4f7ecb9860148cb8e Mon Sep 17 00:00:00 2001 From: pshashk Date: Fri, 8 Feb 2019 21:49:53 +0300 Subject: [PATCH 7/7] fix docstring --- src/layers/stateless.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 66309327..cc1be044 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -40,7 +40,7 @@ but it is more numerically stable. logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ) """ - normalise(x::AbstractArray; dims::Int=1) + normalise(x::AbstractArray; dims=1) Normalises x to mean 0 and standard deviation 1, across the dimensions given by dims. Defaults to normalising over columns. """