Merge branch 'master' into gru

Mike J Innes 2018-01-10 13:59:33 +00:00
commit b44237468e
21 changed files with 243 additions and 39 deletions

View File

@@ -3,5 +3,6 @@ DataFlow 0.2.1
 Juno
 MacroTools 0.3.3
 NNlib
-ForwardDiff
+ForwardDiff 0.5.0
 Requires
+Adapt

View File

@@ -5,6 +5,7 @@ These core layers form the foundation of almost all neural networks.
 ```@docs
 Chain
 Dense
+Conv2D
 ```
 
 ## Recurrent Layers

View File

@@ -7,13 +7,14 @@ module Flux
 using Juno, Requires
 using Lazy: @forward
 
-export Chain, Dense, RNN, LSTM, GRU,
+export Chain, Dense, RNN, LSTM, GRU, Conv2D,
   Dropout, LayerNorm, BatchNorm,
   SGD, ADAM, Momentum, Nesterov, AMSGrad,
   param, params, mapleaves
 
 using NNlib
-export σ, sigmoid, relu, leakyrelu, elu, swish, softmax
+export σ, sigmoid, relu, leakyrelu, elu, swish, softmax,
+  conv2d, maxpool2d, avgpool2d
 
 include("tracker/Tracker.jl")
 using .Tracker
@@ -27,6 +28,7 @@ include("treelike.jl")
 include("layers/stateless.jl")
 include("layers/basic.jl")
+include("layers/conv.jl")
 include("layers/recurrent.jl")
 include("layers/normalisation.jl")

View File

@@ -23,14 +23,14 @@ end
 function symbols()
   load()
-  Symbol.(split(readstring(deps("CMUDict", "cmudict.symbols")),
+  Symbol.(split(readstring(deps("cmudict", "cmudict.symbols")),
                 "\n", keep = false))
 end
 
 function rawdict()
   load()
   Dict(String(xs[1]) => Symbol.(xs[2:end]) for xs in
-       filter(!isempty, split.(split(readstring(deps("CMUDict", "cmudict")), "\n"))))
+       filter(!isempty, split.(split(readstring(deps("cmudict", "cmudict")), "\n"))))
 end
 
 validword(s) = ismatch(r"^[\w\-\.]+$", s)

View File

@@ -63,8 +63,10 @@ struct Dense{F,S,T}
   b::T
 end
 
-Dense(in::Integer, out::Integer, σ = identity; init = initn) =
-  Dense(σ, param(init(out, in)), param(init(out)))
+function Dense(in::Integer, out::Integer, σ = identity;
+               initW = glorot_uniform, initb = zeros)
+  return Dense(σ, param(initW(out, in)), param(initb(out)))
+end
 
 treelike(Dense)
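For reference, a minimal sketch of the new `Dense` keyword interface (defaults are `glorot_uniform` weights and `zeros` biases; `Flux.glorot_normal` and `Flux.initn` are existing initialisers, used here purely as an illustration):

```julia
using Flux

d  = Dense(10, 5, relu)                       # glorot_uniform weights, zero biases
d2 = Dense(10, 5, relu, initW = Flux.glorot_normal, initb = Flux.initn)

d(rand(10))   # 5-element tracked output
```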

src/layers/conv.jl — new file, 33 lines
View File

@@ -0,0 +1,33 @@
+"""
+    Conv2D(size, in=>out)
+    Conv2D(size, in=>out, relu)
+
+Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
+`in` and `out` specify the number of input and output channels respectively.
+
+Data should be stored in HWCN order. In other words, a 100×100 RGB image would
+be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
+
+Takes the keyword arguments `pad` and `stride`.
+"""
+struct Conv2D{F,A}
+  σ::F
+  weight::A
+  stride::Int
+  pad::Int
+end
+
+Conv2D(k::NTuple{2,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
+       init = initn, stride = 1, pad = 0) =
+  Conv2D(σ, param(init(k..., ch...)), stride, pad)
+
+Flux.treelike(Conv2D)
+
+(c::Conv2D)(x) = c.σ.(conv2d(x, c.weight, stride = c.stride, padding = c.pad))
+
+function Base.show(io::IO, l::Conv2D)
+  print(io, "Conv2D((", size(l.weight, 1), ", ", size(l.weight, 2), ")")
+  print(io, ", ", size(l.weight, 3), "=>", size(l.weight, 4))
+  l.σ == identity || print(io, ", ", l.σ)
+  print(io, ")")
+end

View File

@@ -79,8 +79,8 @@ struct RNNCell{D,V}
   h::V
 end
 
-RNNCell(in::Integer, out::Integer, σ = tanh; init = initn) =
-  RNNCell(Dense(in+out, out, σ, init = init), param(init(out)))
+RNNCell(in::Integer, out::Integer, σ = tanh; initW = glorot_uniform, initb = zeros) =
+  RNNCell(Dense(in+out, out, σ, initW = initW, initb = initb), param(initW(out)))
 
 function (m::RNNCell)(h, x)
   h = m.d(combine(x, h))
@@ -113,10 +113,10 @@ struct LSTMCell{D1,D2,V}
   h::V; c::V
 end
 
-function LSTMCell(in, out; init = initn)
-  cell = LSTMCell([Dense(in+out, out, σ, init = init) for _ = 1:3]...,
-                  Dense(in+out, out, tanh, init = init),
-                  param(init(out)), param(init(out)))
+function LSTMCell(in, out; initW = glorot_uniform, initb = zeros)
+  cell = LSTMCell([Dense(in+out, out, σ, initW = initW, initb = initb) for _ = 1:3]...,
+                  Dense(in+out, out, tanh, initW = initW, initb = initb),
+                  param(initW(out)), param(initW(out)))
  cell.forget.b.data .= 1
  return cell
 end
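The cell constructors follow the same convention as `Dense`; a small sketch (the `RNNCell`/`LSTMCell` names are internal, so they are imported explicitly here):

```julia
using Flux
using Flux: RNNCell, LSTMCell, glorot_normal

rcell = RNNCell(10, 5, tanh, initW = glorot_normal)
lcell = LSTMCell(10, 5, initW = glorot_normal, initb = zeros)
```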

View File

@@ -4,8 +4,9 @@ using NNlib: log_fast
 
 mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
 
-crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat) =
-  -sum(y .* log_fast.(ŷ)) / size(y, 2)
+function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
+  return -sum(y .* log_fast.(ŷ) .* weight) / size(y, 2)
+end
 
 @deprecate logloss(x, y) crossentropy(x, y)
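The `weight` keyword broadcasts against the rows (classes) of `y`; the values below are the same ones used in the new `test/layers/stateless.jl` test set:

```julia
using Flux: crossentropy, onehotbatch

ŷ = [.1 .9; .9 .1; .9 .1; .1 .9]'
y = onehotbatch([1, 1, 0, 0], 0:1)

crossentropy(ŷ, y)                     # unweighted, as before
crossentropy(ŷ, y, weight = [2, .5])   # per-class weights, one per row of ŷ
```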

View File

@@ -18,7 +18,9 @@ end
 Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
 
-Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i]
+Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
+Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
+Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
 
 A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
@@ -26,7 +28,7 @@ Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs
 
 batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)
 
-import NNlib.adapt
+import Adapt.adapt
 
 adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
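A small sketch of the new colon indexing (labels chosen arbitrarily for illustration):

```julia
using Flux: onehotbatch

oh = onehotbatch([:b, :a, :c], [:a, :b, :c])   # 3×3 OneHotMatrix

oh[:, 2]        # a single OneHotVector (column 2)
oh[:, [1, 3]]   # a smaller OneHotMatrix holding just those columns
oh[1, 2]        # plain Bool entry; any Integer index is now accepted
```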

View File

@@ -1,15 +1,24 @@
 using Juno
-using Flux.Tracker: back!
+using Flux.Tracker: back!, value
 
 runall(f) = f
 runall(fs::AbstractVector) = () -> foreach(call, fs)
 
 """
-    train!(loss, data, opt; cb = () -> ())
+    train!(loss, data, opt)
 
 For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
-backpropagation and calls the optimizer `opt` and the callback `cb`
-(i.e. `opt()` and `cb()`).
+backpropagation and calls the optimizer `opt`.
+
+Takes a callback as keyword argument `cb`. For example, this will print "training"
+every 10 seconds:
+
+```julia
+Flux.train!(loss, data, opt,
+            cb = throttle(() -> println("training"), 10))
+```
+
+The callback can return `:stop` to interrupt the training loop.
 
 Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
 """
@@ -18,10 +27,10 @@ function train!(loss, data, opt; cb = () -> ())
   opt = runall(opt)
   @progress for d in data
     l = loss(d...)
-    isinf(l.data[]) && error("Loss is Inf")
-    isnan(l.data[]) && error("Loss is NaN")
+    isinf(value(l)) && error("Loss is Inf")
+    isnan(value(l)) && error("Loss is NaN")
     back!(l)
     opt()
-    cb()
+    cb() == :stop && break
   end
 end
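Putting the two changes together, a sketch of a loop that stops early once a target loss is reached (the model `m`, the `data` iterator and the `test_x`/`test_y` arrays are placeholders, not part of this diff):

```julia
using Flux
using Flux: throttle
using Flux.Tracker: value

loss(x, y) = Flux.mse(m(x), y)
evalcb() = value(loss(test_x, test_y)) < 0.01 ? :stop : nothing

Flux.train!(loss, data, SGD(params(m), 0.1),
            cb = throttle(evalcb, 5))   # evaluate at most every 5 seconds
```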

View File

@ -93,7 +93,7 @@ include("back.jl")
include("lib.jl") include("lib.jl")
include("numeric.jl") include("numeric.jl")
import NNlib.adapt import Adapt.adapt
adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad)) adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad))

View File

@@ -12,16 +12,17 @@ function scan(x::TrackedArray)
   return
 end
 
-back(c::Call, Δ) = back(c.func, Δ, c.args...)
-back(::Call{Void}, Δ) = nothing
+back_(f, y, args...) = back(f, args...)
+back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...)
+back_(::Call{Void}, y, Δ) = nothing
 
 function back(x::TrackedArray, Δ)
   ref = x.ref -= 1
   if isdefined(x, :grad)
     x.grad .+= Δ
-    ref == 0 && back(x.f, x.grad)
+    ref == 0 && back_(x.f, x.data, x.grad)
   else
-    ref == 0 && back(x.f, Δ)
+    ref == 0 && back_(x.f, x.data, Δ)
   end
   return
 end
@@ -35,6 +36,9 @@ end
 
 # Interface methods
 
+# TODO: if an error occurs in `back` the refcounts will be broken
+# and `back` will silently fail to update.
+
 function back!(x::TrackedArray, Δ)
   scan(x)
   back(x, Δ)

View File

@@ -44,6 +44,12 @@ function back(::typeof(vcat), Δ, xs, ys)
   @back(ys, Δ[size(xs,1)+1:end, i...])
 end
 
+Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) =
+  TrackedArray(Call(reshape, xs, dims...))
+
+back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
+  back(xs, reshape(Δ, size(xs)))
+
 # Reductions
 
 Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))
@@ -58,6 +64,15 @@ Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
 Base.mean(xs::TrackedArray) = TrackedArray(Call(mean, xs), toarray(xs.data, mean(xs.data)))
 Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region))
 
+LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys))))
+LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys))))
+LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys))))
+
+function back(::typeof(dot), Δ, xs, ys)
+  @back(xs, Δ.*ys)
+  @back(ys, Δ.*xs)
+end
+
 # Hacks to get std working
 
 Base.std(x::TrackedArray; mean = Base.mean(x)) =
   sqrt.(sum((x .- mean).^2) ./ (length(x)-1))
@@ -70,7 +85,7 @@ back(::typeof(mean), Δ, xs::TrackedArray, region) =
 # BLAS
 
-for f in :[*, Ac_mul_B].args
+for f in :[*, Ac_mul_B, A_mul_Bc].args
   @eval begin
     import Base.$f
     $f(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b))
@@ -94,7 +109,12 @@ end
 function back(::typeof(Ac_mul_B), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
   @back(a, A_mul_Bt(Δ, data(b))')
-  @back(b, *(data(a), Δ))
+  @back(b, data(a)*Δ)
+end
+
+function back(::typeof(A_mul_Bc), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
+  @back(a, Δ * data(b))
+  @back(b, At_mul_B(data(a), Δ)')
 end
 
 # Fast path for matrix-vector
@@ -109,12 +129,36 @@
 # NNlib
 
-import NNlib: softmax, ∇softmax
+using NNlib
+import NNlib: softmax, ∇softmax, conv2d, pool
 
 softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs))
 
 back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs)))
 
+# TODO: can store kwargs efficiently in namedtuples
+_conv2d(x, w, stride, pad) = conv2d(x, w, stride = stride, padding = pad)
+
+conv2d(x::TrackedArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) =
+  TrackedArray(Call(_conv2d, x, w, stride, padding))
+conv2d(x::AbstractArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) =
+  TrackedArray(Call(_conv2d, x, w, stride, padding))
+conv2d(x::TrackedArray{<:Any,4}, w::AbstractArray{<:Any,4}; stride = 1, padding = 0) =
+  TrackedArray(Call(_conv2d, x, w, stride, padding))
+
+function back(::typeof(_conv2d), Δ, x, w, stride, pad)
+  @back(x, NNlib.conv2d_grad_x(data(x), data(w), Δ; stride = stride, padding = pad))
+  @back(w, NNlib.conv2d_grad_w(data(x), data(w), Δ; stride = stride, padding = pad))
+end
+
+_pool(x, k, pad, mode) = pool(x, window = k, mode = mode, padding = pad)
+
+pool(x::TrackedArray{<:Any,4}; window = 2, mode = 0, padding = 0) =
+  TrackedArray(Call(_pool, x, window, padding, mode))
+
+back_(::typeof(_pool), y, Δ, x, k, pad, mode) =
+  back(x, NNlib.pool_grad(data(x), y, Δ, window=k, mode=mode, padding=pad))
+
 # Broadcasting
 
 using ForwardDiff: Dual, partials
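A sketch of how the tracked `conv2d` is exercised (gradients land in `.grad` after an explicit `back!`; shapes follow the HWCN convention used above):

```julia
using Flux
using Flux.Tracker: back!

x = param(rand(10, 10, 3, 2))    # HWCN input
w = param(randn(2, 2, 3, 4))     # 2×2 kernel, 3 -> 4 channels

y = conv2d(x, w, stride = 1, padding = 0)   # tracked 9×9×4×2 output
back!(y, ones(y.data))                      # accumulates into x.grad and w.grad
```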

View File

@@ -19,4 +19,4 @@ function ngradient(f, xs::AbstractArray...)
   return grads
 end
 
-gradcheck(f, xs...) = all(isapprox.(ngradient(f, xs...), gradient(f, xs...), rtol = 1e-6))
+gradcheck(f, xs...) = all(isapprox.(ngradient(f, xs...), gradient(f, xs...), rtol = 1e-5))

View File

@@ -1,8 +1,8 @@
 # Arrays
 
 initn(dims...) = randn(dims...)/100
-flatten(xs) = reshape(xs, size(xs, 1), :)
+glorot_uniform(dims...) = (rand(dims...) - 0.5)*sqrt(24.0/(sum(dims)))
+glorot_normal(dims...) = (randn(dims...)*sqrt(2.0/sum(dims)))
 
 unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
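The new initialisers match the bounds checked in `test/utils.jl`; for example:

```julia
using Flux: glorot_uniform, glorot_normal

W = glorot_uniform(100, 400)   # uniform on ±sqrt(6/(100 + 400))
std(W)                         # ≈ sqrt(2/(100 + 400))

Wn = glorot_normal(100, 400)   # normal with std ≈ sqrt(2/(100 + 400))
```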
@@ -93,13 +93,14 @@ but if you'd like to disable the execution on the leading edge, pass
 function throttle(f, timeout; leading=true, trailing=false)
   cooldown = true
   later = nothing
+  result = nothing
 
   function throttled(args...; kwargs...)
     yield()
 
     if cooldown
       if leading
-        f(args...; kwargs...)
+        result = f(args...; kwargs...)
       else
         later = () -> f(args...; kwargs...)
       end
@@ -114,9 +115,28 @@ function throttle(f, timeout; leading=true, trailing=false)
         cooldown = true
       end
     elseif trailing
-      later = () -> f(args...; kwargs...)
+      later = () -> (result = f(args...; kwargs...))
     end
 
-    nothing
+    return result
   end
 end
+
+"""
+    J = jacobian(m,x)
+
+Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J`
+corresponds to the gradient `J[i,:] = ∇ₓ(m(x)[i])`.
+"""
+function jacobian(m,x)
+  xp = param(x)
+  y  = m(xp)
+  k  = length(y)
+  n  = length(x)
+  J  = Matrix{eltype(x)}(n,k)
+  for i = 1:k
+    Flux.back!(y[i]) # Populate gradient accumulator
+    J[:,i] = xp.grad
+    xp.grad .*= 0 # Reset gradient accumulator
+  end
+  J'
+end
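Usage of `jacobian` mirrors the new "Jacobian" test set below; for a linear model the result is just the weight matrix:

```julia
using Flux
using Flux: jacobian

A = param(randn(2, 2))
m(x) = A*x

J = jacobian(m, randn(2))   # J[i, :] ≈ ∇ₓ (m(x)[i])
J ≈ A.data                  # true for a linear map
```

And since `throttle` now caches `result`, a throttled function returns the last value instead of `nothing`; a small sketch:

```julia
using Flux: throttle

f = throttle(+, 1)   # at most one underlying call per second
f(1, 2)              # 3 — the leading call runs immediately
f(4, 5)              # still 3: during the cooldown the cached result is returned
```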

View File

@@ -1,3 +1,8 @@
 using Flux.Data
+using Base.Test
 
 @test cmudict()["CATASTROPHE"] == :[K,AH0,T,AE1,S,T,R,AH0,F,IY0].args
+
+@test length(CMUDict.phones()) == 39
+
+@test length(CMUDict.symbols()) == 84

test/layers/stateless.jl — new file, 26 lines
View File

@@ -0,0 +1,26 @@
+using Flux: onehotbatch, mse, crossentropy
+
+@testset "losses" begin
+  # First, regression-style y's
+  y = [1, 1, 0, 0]
+  y_hat = [.9, .1, .1, .9]
+
+  @testset "mse" begin
+    @test mse(y_hat, y) ≈ (.1^2 + .9^2)/2
+  end
+
+  # Now onehot y's
+  y = onehotbatch([1, 1, 0, 0], 0:1)
+  y_hat = [.1 .9; .9 .1; .9 .1; .1 .9]'
+  y_logloss = 1.203972804325936
+
+  @testset "crossentropy" begin
+    @test crossentropy(y_hat, y) ≈ y_logloss
+  end
+
+  @testset "weighted_crossentropy" begin
+    @test crossentropy(y_hat, y, weight = ones(2)) ≈ y_logloss
+    @test crossentropy(y_hat, y, weight = [.5, .5]) ≈ y_logloss/2
+    @test crossentropy(y_hat, y, weight = [2, .5]) ≈ 1.5049660054074199
+  end
+end

View File

@@ -15,3 +15,15 @@ using Flux.Tracker
     @test Flux.mse(w, w′) < 0.01
   end
 end
+
+@testset "Training Loop" begin
+  i = 0
+  l = param(1)
+
+  Flux.train!(() -> (sleep(0.1); i += 1; l),
+              Iterators.repeated((), 100),
+              ()->(),
+              cb = Flux.throttle(() -> (i > 3 && :stop), 1))
+
+  @test 3 < i < 50
+end

View File

@@ -5,6 +5,8 @@ using Flux, Base.Test
 include("utils.jl")
 include("tracker.jl")
 include("layers/normalisation.jl")
+include("layers/stateless.jl")
 include("optimise.jl")
+include("data.jl")
 
 end

View File

@@ -1,5 +1,6 @@
 using Flux.Tracker, Base.Test, NNlib
 using Flux.Tracker: gradcheck
+using NNlib
 
 gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(f(xs...)), xs...)
 gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@@ -10,6 +11,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
 @test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
 
 @test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
+@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))
 
 @test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))
@@ -37,9 +39,15 @@ end
 @test gradtest(x -> std(x), rand(5,5))
 @test gradtest(x -> std(x, 1), rand(5,5))
 
+@test gradtest((x, y) -> x .* y, rand(5), rand(5))
+
 @test gradtest(rand(5)) do x
   y = x.^2
   2y + x
 end
 
+@test gradtest(conv2d, rand(10, 10, 3, 2), randn(2, 2, 3, 2))
+@test gradtest(x -> maxpool2d(x, 2), rand(10, 10, 3, 2))
+@test gradtest(x -> avgpool2d(x, 2), rand(10, 10, 3, 2))
+
 end #testset

View File

@@ -1,4 +1,4 @@
-using Flux: throttle
+using Flux: throttle, initn, glorot_uniform, glorot_normal, jacobian
 
 @testset "Throttle" begin
   @testset "default behaviour" begin
@@ -47,3 +47,35 @@ using Flux: throttle
     @test a == [1, 3]
   end
 end
+
+@testset "Jacobian" begin
+  A = param(randn(2,2))
+  x = randn(2)
+  m(x) = A*x
+  y = m(x)
+  J = jacobian(m,x)
+  @test J ≈ A.data
+end
+
+@testset "Initialization" begin
+  # Set random seed so that these tests don't fail randomly
+  srand(0)
+  # initn() should yield a kernel with stddev ~= 1e-2
+  v = initn(10, 10)
+  @test std(v) > 0.9*1e-2
+  @test std(v) < 1.1*1e-2
+
+  # glorot_uniform should yield a kernel with stddev ~= sqrt(6/(n_in + n_out)),
+  # and glorot_normal should yield a kernel with stddev ~= sqrt(2/(n_in + n_out))
+  for (n_in, n_out) in [(100, 100), (100, 400)]
+    v = glorot_uniform(n_in, n_out)
+    @test minimum(v) > -1.1*sqrt(6/(n_in + n_out))
+    @test minimum(v) < -0.9*sqrt(6/(n_in + n_out))
+    @test maximum(v) > 0.9*sqrt(6/(n_in + n_out))
+    @test maximum(v) < 1.1*sqrt(6/(n_in + n_out))
+
+    v = glorot_normal(n_in, n_out)
+    @test std(v) > 0.9*sqrt(2/(n_in + n_out))
+    @test std(v) < 1.1*sqrt(2/(n_in + n_out))
+  end
+end