fix std usage

This commit is contained in:
Mike J Innes 2018-06-20 15:18:07 +01:00
parent 88a265154c
commit 7057ca739e
4 changed files with 7 additions and 3 deletions

View File

@ -4,7 +4,7 @@ module Flux
# Zero Flux Given
using Juno, Requires, Reexport
using Juno, Requires, Reexport, StatsBase
using MacroTools: @forward
export Chain, Dense, RNN, LSTM, GRU, Conv,

View File

@ -233,10 +233,12 @@ dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
@grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs)
using StatsBase
# Hacks to get std working
Base.std(x::TrackedArray; mean = Base.mean(x)) =
StatsBase.std(x::TrackedArray; mean = Base.mean(x)) =
sqrt.(sum((x .- mean).^2) ./ (length(x)-1))
Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) =
StatsBase.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) =
sqrt.(sum((x .- mean).^2, dim) ./ (size(x, dim)-1))
LinearAlgebra.vecnorm(x::TrackedArray, p::Real = 2) =

View File

@ -1,6 +1,7 @@
using Flux.Tracker, Base.Test, NNlib
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
using NNlib: conv
using StatsBase
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
gradtest(f, dims...) = gradtest(f, rand.(dims)...)

View File

@ -1,4 +1,5 @@
using Flux: throttle, initn, glorot_uniform, glorot_normal, jacobian
using StatsBase: std
@testset "Throttle" begin
@testset "default behaviour" begin