Merge branch 'master' into gru
commit b44237468e
REQUIRE
@@ -3,5 +3,6 @@ DataFlow 0.2.1
 Juno
 MacroTools 0.3.3
 NNlib
-ForwardDiff
+ForwardDiff 0.5.0
 Requires
+Adapt
docs/src/models/layers.md
@@ -5,6 +5,7 @@ These core layers form the foundation of almost all neural networks.
 ```@docs
 Chain
 Dense
+Conv2D
 ```

 ## Recurrent Layers
src/Flux.jl
@@ -7,13 +7,14 @@ module Flux
 using Juno, Requires
 using Lazy: @forward

-export Chain, Dense, RNN, LSTM, GRU,
+export Chain, Dense, RNN, LSTM, GRU, Conv2D,
   Dropout, LayerNorm, BatchNorm,
   SGD, ADAM, Momentum, Nesterov, AMSGrad,
   param, params, mapleaves

 using NNlib
-export σ, sigmoid, relu, leakyrelu, elu, swish, softmax
+export σ, sigmoid, relu, leakyrelu, elu, swish, softmax,
+  conv2d, maxpool2d, avgpool2d

 include("tracker/Tracker.jl")
 using .Tracker
@@ -27,6 +28,7 @@ include("treelike.jl")

 include("layers/stateless.jl")
 include("layers/basic.jl")
+include("layers/conv.jl")
 include("layers/recurrent.jl")
 include("layers/normalisation.jl")
src/data/cmudict.jl
@@ -23,14 +23,14 @@ end

 function symbols()
   load()
-  Symbol.(split(readstring(deps("CMUDict", "cmudict.symbols")),
+  Symbol.(split(readstring(deps("cmudict", "cmudict.symbols")),
                 "\n", keep = false))
 end

 function rawdict()
   load()
   Dict(String(xs[1]) => Symbol.(xs[2:end]) for xs in
-       filter(!isempty, split.(split(readstring(deps("CMUDict", "cmudict")), "\n"))))
+       filter(!isempty, split.(split(readstring(deps("cmudict", "cmudict")), "\n"))))
 end

 validword(s) = ismatch(r"^[\w\-\.]+$", s)
src/layers/basic.jl
@@ -63,8 +63,10 @@ struct Dense{F,S,T}
   b::T
 end

-Dense(in::Integer, out::Integer, σ = identity; init = initn) =
-  Dense(σ, param(init(out, in)), param(init(out)))
+function Dense(in::Integer, out::Integer, σ = identity;
+               initW = glorot_uniform, initb = zeros)
+  return Dense(σ, param(initW(out, in)), param(initb(out)))
+end

 treelike(Dense)
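For illustration, a sketch of the reworked constructor in use; shapes and the custom-initialiser choice are arbitrary, not part of the commit:

```julia
using Flux

d  = Dense(10, 5, relu)                              # weights from glorot_uniform, biases from zeros
d2 = Dense(10, 5, relu, initW = Flux.glorot_normal)  # swap in a custom weight initialiser

x = rand(10)
y = d(x)   # 5-element output, relu applied elementwise
```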
src/layers/conv.jl (new file, 33 lines)
@@ -0,0 +1,33 @@
+"""
+    Conv2D(size, in=>out)
+    Conv2D(size, in=>out, relu)
+
+Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
+`in` and `out` specify the number of input and output channels respectively.
+
+Data should be stored in HWCN order. In other words, a 100×100 RGB image would
+be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
+
+Takes the keyword arguments `pad` and `stride`.
+"""
+struct Conv2D{F,A}
+  σ::F
+  weight::A
+  stride::Int
+  pad::Int
+end
+
+Conv2D(k::NTuple{2,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
+       init = initn, stride = 1, pad = 0) =
+  Conv2D(σ, param(init(k..., ch...)), stride, pad)
+
+Flux.treelike(Conv2D)
+
+(c::Conv2D)(x) = c.σ.(conv2d(x, c.weight, stride = c.stride, padding = c.pad))
+
+function Base.show(io::IO, l::Conv2D)
+  print(io, "Conv2D((", size(l.weight, 1), ", ", size(l.weight, 2), ")")
+  print(io, ", ", size(l.weight, 3), "=>", size(l.weight, 4))
+  l.σ == identity || print(io, ", ", l.σ)
+  print(io, ")")
+end
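A minimal usage sketch of the new layer, following the HWCN convention from the docstring (shapes are illustrative):

```julia
using Flux

c = Conv2D((2, 2), 3=>16, relu)   # 2×2 kernels mapping 3 input channels to 16

x = rand(100, 100, 3, 50)   # a batch of 50 100×100 RGB images, stored HWCN
y = c(x)                    # 99×99×16×50 with the default stride = 1, pad = 0
```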
src/layers/recurrent.jl
@@ -79,8 +79,8 @@ struct RNNCell{D,V}
   h::V
 end

-RNNCell(in::Integer, out::Integer, σ = tanh; init = initn) =
-  RNNCell(Dense(in+out, out, σ, init = init), param(init(out)))
+RNNCell(in::Integer, out::Integer, σ = tanh; initW = glorot_uniform, initb = zeros) =
+  RNNCell(Dense(in+out, out, σ, initW = initW, initb = initb), param(initW(out)))

 function (m::RNNCell)(h, x)
   h = m.d(combine(x, h))
@@ -113,10 +113,10 @@ struct LSTMCell{D1,D2,V}
   h::V; c::V
 end

-function LSTMCell(in, out; init = initn)
-  cell = LSTMCell([Dense(in+out, out, σ, init = init) for _ = 1:3]...,
-                  Dense(in+out, out, tanh, init = init),
-                  param(init(out)), param(init(out)))
+function LSTMCell(in, out; initW = glorot_uniform, initb = zeros)
+  cell = LSTMCell([Dense(in+out, out, σ, initW = initW, initb = initb) for _ = 1:3]...,
+                  Dense(in+out, out, tanh, initW = initW, initb = initb),
+                  param(initW(out)), param(initW(out)))
   cell.forget.b.data .= 1
   return cell
 end
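The recurrent cells pick up the same keyword change. A sketch, under the assumption that the `RNN`/`LSTM` wrappers forward their keyword arguments to the cells:

```julia
using Flux

rnn  = RNN(10, 5, tanh, initW = Flux.glorot_normal)  # assumes kwargs reach RNNCell
lstm = LSTM(10, 5)                                   # forget-gate bias set to 1, as above

x = rand(10)
h = rnn(x)   # 5-element hidden state
```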
src/layers/stateless.jl
@@ -4,8 +4,9 @@ using NNlib: log_fast

 mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)

-crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat) =
-  -sum(y .* log_fast.(ŷ)) / size(y, 2)
+function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
+  return -sum(y .* log_fast.(ŷ) .* weight) / size(y, 2)
+end

 @deprecate logloss(x, y) crossentropy(x, y)
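The new `weight` keyword broadcasts a per-class weight into the sum. The numbers here come from the test file added later in this commit:

```julia
using Flux: onehotbatch, crossentropy

y = onehotbatch([1, 1, 0, 0], 0:1)
ŷ = [.1 .9; .9 .1; .9 .1; .1 .9]'

crossentropy(ŷ, y)                     # ≈ 1.203972804325936
crossentropy(ŷ, y, weight = [.5, .5])  # ≈ half of the above
```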
src/onehot.jl
@@ -18,7 +18,9 @@ end

 Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))

-Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i]
+Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
+Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
+Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])

 A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
@@ -26,7 +28,7 @@ Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs

 batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)

-import NNlib.adapt
+import Adapt.adapt

 adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
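A sketch of what the added `getindex` methods enable (values illustrative):

```julia
using Flux: onehotbatch

m = onehotbatch([0, 1, 0], 0:1)   # a 2×3 OneHotMatrix

m[1, 2]       # scalar indexing now accepts any Integer subtype
m[:, 2]       # a single column, as a OneHotVector
m[:, [1, 3]]  # selected columns, still a OneHotMatrix
```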
src/optimise/train.jl
@@ -1,15 +1,24 @@
 using Juno
-using Flux.Tracker: back!
+using Flux.Tracker: back!, value

 runall(f) = f
 runall(fs::AbstractVector) = () -> foreach(call, fs)

 """
-    train!(loss, data, opt; cb = () -> ())
+    train!(loss, data, opt)

 For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
-backpropagation and calls the optimizer `opt` and the callback `cb`
-(i.e. `opt()` and `cb()`).
+backpropagation and calls the optimizer `opt`.
+
+Takes a callback as keyword argument `cb`. For example, this will print "training"
+every 10 seconds:
+
+```julia
+Flux.train!(loss, data, opt,
+            cb = throttle(() -> println("training"), 10))
+```
+
+The callback can return `:stop` to interrupt the training loop.

 Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
 """
@@ -18,10 +27,10 @@ function train!(loss, data, opt; cb = () -> ())
   opt = runall(opt)
   @progress for d in data
     l = loss(d...)
-    isinf(l.data[]) && error("Loss is Inf")
-    isnan(l.data[]) && error("Loss is NaN")
+    isinf(value(l)) && error("Loss is Inf")
+    isnan(value(l)) && error("Loss is NaN")
     back!(l)
     opt()
-    cb()
+    cb() == :stop && break
   end
 end
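A condensed sketch of the new `:stop` hook, adapted from the "Training Loop" test added below:

```julia
using Flux

i = 0
l = param(1)   # stands in for a real tracked loss

Flux.train!(() -> (i += 1; l),            # loss closure over no data
            Iterators.repeated((), 100),  # 100 empty datapoints
            () -> (),                     # no-op "optimiser"
            cb = () -> i > 3 && :stop)    # returning :stop breaks the loop early
```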
src/tracker/Tracker.jl
@@ -93,7 +93,7 @@ include("back.jl")
 include("lib.jl")
 include("numeric.jl")

-import NNlib.adapt
+import Adapt.adapt

 adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad))
src/tracker/back.jl
@@ -12,16 +12,17 @@ function scan(x::TrackedArray)
   return
 end

-back(c::Call, Δ) = back(c.func, Δ, c.args...)
-back(::Call{Void}, Δ) = nothing
+back_(f, y, args...) = back(f, args...)
+back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...)
+back_(::Call{Void}, y, Δ) = nothing

 function back(x::TrackedArray, Δ)
   ref = x.ref -= 1
   if isdefined(x, :grad)
     x.grad .+= Δ
-    ref == 0 && back(x.f, x.grad)
+    ref == 0 && back_(x.f, x.data, x.grad)
   else
-    ref == 0 && back(x.f, Δ)
+    ref == 0 && back_(x.f, x.data, Δ)
   end
   return
 end
@@ -35,6 +36,9 @@ end

 # Interface methods

+# TODO: if an error occurs in `back` the refcounts will be broken
+# and `back` will silently fail to update.
+
 function back!(x::TrackedArray, Δ)
   scan(x)
   back(x, Δ)
src/tracker/lib.jl
@@ -44,6 +44,12 @@ function back(::typeof(vcat), Δ, xs, ys)
   @back(ys, Δ[size(xs,1)+1:end, i...])
 end

+Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) =
+  TrackedArray(Call(reshape, xs, dims...))
+
+back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
+  back(xs, reshape(Δ, size(xs)))
+
 # Reductions

 Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))
@@ -58,6 +64,15 @@ Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
 Base.mean(xs::TrackedArray) = TrackedArray(Call(mean, xs), toarray(xs.data, mean(xs.data)))
 Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region))

+LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys))))
+LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys))))
+LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys))))
+
+function back(::typeof(dot), Δ, xs, ys)
+  @back(xs, Δ.*ys)
+  @back(ys, Δ.*xs)
+end
+
 # Hacks to get std working
 Base.std(x::TrackedArray; mean = Base.mean(x)) =
   sqrt.(sum((x .- mean).^2) ./ (length(x)-1))
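A sketch of the new tracked `dot` and its gradient rule, using the Tracker internals defined in this commit (shapes illustrative):

```julia
using Flux
using Flux.Tracker: back!

x, y = param(rand(5)), param(rand(5))
s = dot(x, y)    # a tracked scalar
back!(s, 1)

x.grad ≈ y.data  # true: ∂(x⋅y)/∂x = y, per back(::typeof(dot), ...)
```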
@@ -70,7 +85,7 @@ back(::typeof(mean), Δ, xs::TrackedArray, region) =

 # BLAS

-for f in :[*, Ac_mul_B].args
+for f in :[*, Ac_mul_B, A_mul_Bc].args
   @eval begin
     import Base.$f
     $f(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b))
@@ -94,7 +109,12 @@ end

 function back(::typeof(Ac_mul_B), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
   @back(a, A_mul_Bt(Δ, data(b))')
-  @back(b, *(data(a), Δ))
+  @back(b, data(a)*Δ)
+end
+
+function back(::typeof(A_mul_Bc), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
+  @back(a, Δ * data(b))
+  @back(b, At_mul_B(data(a), Δ)')
 end

 # Fast path for matrix-vector
@@ -109,12 +129,36 @@ end

 # NNlib

-import NNlib: softmax, ∇softmax
+using NNlib
+import NNlib: softmax, ∇softmax, conv2d, pool

 softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs))

 back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs)))

+# TODO: can store kwargs efficiently in namedtuples
+_conv2d(x, w, stride, pad) = conv2d(x, w, stride = stride, padding = pad)
+
+conv2d(x::TrackedArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) =
+  TrackedArray(Call(_conv2d, x, w, stride, padding))
+conv2d(x::AbstractArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) =
+  TrackedArray(Call(_conv2d, x, w, stride, padding))
+conv2d(x::TrackedArray{<:Any,4}, w::AbstractArray{<:Any,4}; stride = 1, padding = 0) =
+  TrackedArray(Call(_conv2d, x, w, stride, padding))
+
+function back(::typeof(_conv2d), Δ, x, w, stride, pad)
+  @back(x, NNlib.conv2d_grad_x(data(x), data(w), Δ; stride = stride, padding = pad))
+  @back(w, NNlib.conv2d_grad_w(data(x), data(w), Δ; stride = stride, padding = pad))
+end
+
+_pool(x, k, pad, mode) = pool(x, window = k, mode = mode, padding = pad)
+
+pool(x::TrackedArray{<:Any,4}; window = 2, mode = 0, padding = 0) =
+  TrackedArray(Call(_pool, x, window, padding, mode))
+
+back_(::typeof(_pool), y, Δ, x, k, pad, mode) =
+  back(x, NNlib.pool_grad(data(x), y, Δ, window=k, mode=mode, padding=pad))
+
 # Broadcasting

 using ForwardDiff: Dual, partials
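These wrappers are what the new gradient tests at the bottom of test/tracker.jl exercise; pooling is also the first user of the `back_` path, which threads the forward output `y` through to `pool_grad`. Condensed from those tests:

```julia
using Flux.Tracker: gradcheck
using NNlib

gradtest(f, xs...) = gradcheck((xs...) -> sum(f(xs...)), xs...)

gradtest(conv2d, rand(10, 10, 3, 2), randn(2, 2, 3, 2))  # inputs and conv weights
gradtest(x -> maxpool2d(x, 2), rand(10, 10, 3, 2))       # pooling via back_/pool_grad
```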
src/tracker/numeric.jl
@@ -19,4 +19,4 @@ function ngradient(f, xs::AbstractArray...)
   return grads
 end

-gradcheck(f, xs...) = all(isapprox.(ngradient(f, xs...), gradient(f, xs...), rtol = 1e-6))
+gradcheck(f, xs...) = all(isapprox.(ngradient(f, xs...), gradient(f, xs...), rtol = 1e-5))
src/utils.jl
@@ -1,8 +1,8 @@
 # Arrays

 initn(dims...) = randn(dims...)/100
-
-flatten(xs) = reshape(xs, size(xs, 1), :)
+glorot_uniform(dims...) = (rand(dims...) - 0.5)*sqrt(24.0/(sum(dims)))
+glorot_normal(dims...) = (randn(dims...)*sqrt(2.0/sum(dims)))

 unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
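A quick check on what the two initialisers produce; the bounds mirror the new tests in test/utils.jl:

```julia
using Flux: glorot_uniform, glorot_normal

v = glorot_uniform(100, 100)   # values lie within ±sqrt(6/(100+100)) ≈ ±0.173
w = glorot_normal(100, 100)    # std(w) ≈ sqrt(2/(100+100)) = 0.1
```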
@@ -93,13 +93,14 @@ but if you'd like to disable the execution on the leading edge, pass
 function throttle(f, timeout; leading=true, trailing=false)
   cooldown = true
   later = nothing
+  result = nothing

   function throttled(args...; kwargs...)
     yield()

     if cooldown
       if leading
-        f(args...; kwargs...)
+        result = f(args...; kwargs...)
       else
         later = () -> f(args...; kwargs...)
       end
@@ -114,9 +115,28 @@ function throttle(f, timeout; leading=true, trailing=false)
         cooldown = true
       end
     elseif trailing
-      later = () -> f(args...; kwargs...)
+      later = () -> (result = f(args...; kwargs...))
     end

-    nothing
+    return result
   end
 end

+"""
+    J = jacobian(m,x)
+
+Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J` corresponds to the gradient `J[i,:] = ∇ₓ(m(x)[i])`
+"""
+function jacobian(m,x)
+  xp = param(x)
+  y = m(xp)
+  k = length(y)
+  n = length(x)
+  J = Matrix{eltype(x)}(n,k)
+  for i = 1:k
+    Flux.back!(y[i]) # Populate gradient accumulator
+    J[:,i] = xp.grad
+    xp.grad .*= 0 # Reset gradient accumulator
+  end
+  J'
+end
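Usage of the new `jacobian` helper, matching the test added in test/utils.jl: for a linear map, the jacobian is just the weight matrix:

```julia
using Flux
using Flux: jacobian

A = param(randn(2, 2))
m(x) = A*x

J = jacobian(m, randn(2))
J ≈ A.data   # true: the jacobian of x -> A*x is A itself
```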
test/data.jl
@@ -1,3 +1,8 @@
 using Flux.Data
+using Base.Test

 @test cmudict()["CATASTROPHE"] == :[K,AH0,T,AE1,S,T,R,AH0,F,IY0].args
+
+@test length(CMUDict.phones()) == 39
+
+@test length(CMUDict.symbols()) == 84
test/layers/stateless.jl (new file, 26 lines)
@@ -0,0 +1,26 @@
+using Flux: onehotbatch, mse, crossentropy
+
+@testset "losses" begin
+  # First, regression-style y's
+  y = [1, 1, 0, 0]
+  y_hat = [.9, .1, .1, .9]
+
+  @testset "mse" begin
+    @test mse(y_hat, y) ≈ (.1^2 + .9^2)/2
+  end
+
+  # Now onehot y's
+  y = onehotbatch([1, 1, 0, 0], 0:1)
+  y_hat = [.1 .9; .9 .1; .9 .1; .1 .9]'
+  y_logloss = 1.203972804325936
+
+  @testset "crossentropy" begin
+    @test crossentropy(y_hat, y) ≈ y_logloss
+  end
+
+  @testset "weighted_crossentropy" begin
+    @test crossentropy(y_hat, y, weight = ones(2)) ≈ y_logloss
+    @test crossentropy(y_hat, y, weight = [.5, .5]) ≈ y_logloss/2
+    @test crossentropy(y_hat, y, weight = [2, .5]) ≈ 1.5049660054074199
+  end
+end
test/optimise.jl
@@ -15,3 +15,15 @@ using Flux.Tracker
     @test Flux.mse(w, w′) < 0.01
   end
 end
+
+@testset "Training Loop" begin
+  i = 0
+  l = param(1)
+
+  Flux.train!(() -> (sleep(0.1); i += 1; l),
+              Iterators.repeated((), 100),
+              ()->(),
+              cb = Flux.throttle(() -> (i > 3 && :stop), 1))
+
+  @test 3 < i < 50
+end
test/runtests.jl
@@ -5,6 +5,8 @@ using Flux, Base.Test
 include("utils.jl")
 include("tracker.jl")
 include("layers/normalisation.jl")
+include("layers/stateless.jl")
 include("optimise.jl")
 include("data.jl")
+
 end
test/tracker.jl
@@ -1,5 +1,6 @@
 using Flux.Tracker, Base.Test, NNlib
 using Flux.Tracker: gradcheck
+using NNlib

 gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(f(xs...)), xs...)
 gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@@ -10,6 +11,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
 @test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)

 @test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
+@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))

 @test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))
@@ -37,9 +39,15 @@ end
 @test gradtest(x -> std(x), rand(5,5))
 @test gradtest(x -> std(x, 1), rand(5,5))

+@test gradtest((x, y) -> x .* y, rand(5), rand(5))
+
 @test gradtest(rand(5)) do x
   y = x.^2
   2y + x
 end

+@test gradtest(conv2d, rand(10, 10, 3, 2), randn(2, 2, 3, 2))
+@test gradtest(x -> maxpool2d(x, 2), rand(10, 10, 3, 2))
+@test gradtest(x -> avgpool2d(x, 2), rand(10, 10, 3, 2))
+
 end #testset
test/utils.jl
@@ -1,4 +1,4 @@
-using Flux: throttle
+using Flux: throttle, initn, glorot_uniform, glorot_normal, jacobian

 @testset "Throttle" begin
   @testset "default behaviour" begin
@@ -47,3 +47,35 @@ using Flux: throttle
   @test a == [1, 3]
 end
 end

+@testset "Jacobian" begin
+  A = param(randn(2,2))
+  x = randn(2)
+  m(x) = A*x
+  y = m(x)
+  J = jacobian(m,x)
+  @test J ≈ A.data
+end
+
+@testset "Initialization" begin
+  # Set random seed so that these tests don't fail randomly
+  srand(0)
+  # initn() should yield a kernel with stddev ~= 1e-2
+  v = initn(10, 10)
+  @test std(v) > 0.9*1e-2
+  @test std(v) < 1.1*1e-2
+
+  # glorot_uniform should yield values within ±sqrt(6/(n_in + n_out)),
+  # and glorot_normal should yield a kernel with stddev ~= sqrt(2/(n_in + n_out))
+  for (n_in, n_out) in [(100, 100), (100, 400)]
+    v = glorot_uniform(n_in, n_out)
+    @test minimum(v) > -1.1*sqrt(6/(n_in + n_out))
+    @test minimum(v) < -0.9*sqrt(6/(n_in + n_out))
+    @test maximum(v) > 0.9*sqrt(6/(n_in + n_out))
+    @test maximum(v) < 1.1*sqrt(6/(n_in + n_out))
+
+    v = glorot_normal(n_in, n_out)
+    @test std(v) > 0.9*sqrt(2/(n_in + n_out))
+    @test std(v) < 1.1*sqrt(2/(n_in + n_out))
+  end
+end