fix std usage
This commit is contained in:
parent
88a265154c
commit
7057ca739e
|
@ -4,7 +4,7 @@ module Flux
|
|||
|
||||
# Zero Flux Given
|
||||
|
||||
using Juno, Requires, Reexport
|
||||
using Juno, Requires, Reexport, StatsBase
|
||||
using MacroTools: @forward
|
||||
|
||||
export Chain, Dense, RNN, LSTM, GRU, Conv,
|
||||
|
|
|
@ -233,10 +233,12 @@ dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
|
|||
|
||||
@grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs)
|
||||
|
||||
using StatsBase
|
||||
|
||||
# Hacks to get std working
|
||||
Base.std(x::TrackedArray; mean = Base.mean(x)) =
|
||||
StatsBase.std(x::TrackedArray; mean = Base.mean(x)) =
|
||||
sqrt.(sum((x .- mean).^2) ./ (length(x)-1))
|
||||
Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) =
|
||||
StatsBase.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) =
|
||||
sqrt.(sum((x .- mean).^2, dim) ./ (size(x, dim)-1))
|
||||
|
||||
LinearAlgebra.vecnorm(x::TrackedArray, p::Real = 2) =
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
using Flux.Tracker, Base.Test, NNlib
|
||||
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
|
||||
using NNlib: conv
|
||||
using StatsBase
|
||||
|
||||
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
|
||||
gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
using Flux: throttle, initn, glorot_uniform, glorot_normal, jacobian
|
||||
using StatsBase: std
|
||||
|
||||
@testset "Throttle" begin
|
||||
@testset "default behaviour" begin
|
||||
|
|
Loading…
Reference in New Issue