Merge branch 'master' into cudnn_batchnorm
commit 8bea60d980
@@ -5,3 +5,4 @@ docs/build/
docs/site/
docs/flux.css
deps
Manifest.toml
@@ -5,10 +5,15 @@ os:
# - osx
julia:
- 0.7
- 1.0
- nightly
# uncomment the following lines to override the default test script
# script:
# - if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
# - julia -e 'Pkg.clone(pwd()); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'
matrix:
allow_failures:
- julia: nightly
after_success:
- julia -e 'Pkg.add("Documenter")'
- julia -e 'cd(Pkg.dir("Flux")); include(joinpath("docs", "make.jl"))'
- julia -e 'using Pkg; Pkg.add("Documenter"); Pkg.add("NNlib")'
- julia -e 'using Pkg; cd(Pkg.dir("Flux")); include(joinpath("docs", "make.jl"))'
REQUIRE
@@ -1,10 +1,10 @@
julia 0.7-
julia 0.7
Juno
MacroTools 0.3.3
NNlib
Requires
Adapt
GZip
CodecZlib
Colors
ZipFile
AbstractTrees
@@ -26,6 +26,6 @@ deploydocs(
repo = "github.com/FluxML/Flux.jl.git",
target = "build",
osname = "linux",
julia = "0.6",
julia = "1.0",
deps = nothing,
make = nothing)
@@ -3,7 +3,7 @@
It's common to encode categorical variables (like `true`, `false` or `cat`, `dog`) in "one-of-k" or ["one-hot"](https://en.wikipedia.org/wiki/One-hot) form. Flux provides the `onehot` function to make this easy.

```
julia> using Flux: onehot
julia> using Flux: onehot, onecold

julia> onehot(:b, [:a, :b, :c])
3-element Flux.OneHotVector:
@@ -18,22 +18,22 @@ julia> onehot(:c, [:a, :b, :c])
true
```

The inverse is `argmax` (which can take a general probability distribution, as well as just booleans).
The inverse is `onecold` (which can take a general probability distribution, as well as just booleans).

```julia
julia> argmax(ans, [:a, :b, :c])
julia> onecold(ans, [:a, :b, :c])
:c

julia> argmax([true, false, false], [:a, :b, :c])
julia> onecold([true, false, false], [:a, :b, :c])
:a

julia> argmax([0.3, 0.2, 0.5], [:a, :b, :c])
julia> onecold([0.3, 0.2, 0.5], [:a, :b, :c])
:c
```

## Batches

`onehotbatch` creates a batch (matrix) of one-hot vectors, and `argmax` treats matrices as batches.
`onehotbatch` creates a batch (matrix) of one-hot vectors, and `onecold` treats matrices as batches.

```julia
julia> using Flux: onehotbatch
@@ -1,18 +1,17 @@
# Flux: The Julia Machine Learning Library

Flux is a library for machine learning. It comes "batteries-included" with many useful tools built in, but also lets you use the full power of the Julia language where you need it. The whole stack is implemented in clean Julia code (right down to the [GPU kernels](https://github.com/FluxML/CuArrays.jl)) and any part can be tweaked to your liking.
Flux is a library for machine learning. It comes "batteries-included" with many useful tools built in, but also lets you use the full power of the Julia language where you need it. We follow a few key principles:

# Installation
* **Doing the obvious thing**. Flux has relatively few explicit APIs for features like regularisation or embeddings. Instead, writing down the mathematical form will work – and be fast.
* **You could have written Flux**. All of it, from [LSTMs](https://github.com/FluxML/Flux.jl/blob/ec16a2c77dbf6ab8b92b0eecd11661be7a62feef/src/layers/recurrent.jl#L131) to [GPU kernels](https://github.com/JuliaGPU/CuArrays.jl), is straightforward Julia code. When in doubt, it’s well worth looking at [the source](https://github.com/FluxML/Flux.jl/). If you need something different, you can easily roll your own.
* **Play nicely with others**. Flux works well with Julia libraries from [data frames](https://github.com/JuliaComputing/JuliaDB.jl) and [images](https://github.com/JuliaImages/Images.jl) to [differential equation solvers](https://github.com/JuliaDiffEq/DifferentialEquations.jl), so you can easily build complex data processing pipelines that integrate Flux models.

Install [Julia 0.6.0 or later](https://julialang.org/downloads/), if you haven't already.
## Installation

```julia
Pkg.add("Flux")
# Optional but recommended
Pkg.update() # Keep your packages up to date
Pkg.test("Flux") # Check things installed correctly
```
Download [Julia 1.0](https://julialang.org/) or later, if you haven't already. You can add Flux using Julia's package manager, by typing `] add Flux` in the Julia prompt.

Start with the [basics](models/basics.md). The [model zoo](https://github.com/FluxML/model-zoo/) is also a good starting point for many common kinds of models.
If you have CUDA you can also run `] add CuArrays` to get GPU support; see [here](gpu.md) for more details.

See [GPU support](gpu.md) for more details on installing and using Flux with GPUs.
## Learning Flux

There are several different ways to learn Flux. If you just want to get started writing models, the [model zoo](https://github.com/FluxML/model-zoo/) gives good starting points for many common ones. This documentation provides a reference to all of Flux's APIs, as well as a from-scratch introduction to Flux's take on models and how they work. Once you understand these docs, congratulations, you also understand [Flux's source code](https://github.com/FluxML/Flux.jl), which is intended to be concise, legible and a good reference for more advanced concepts.
@@ -172,7 +172,7 @@ using Flux

layers = [Dense(10, 5, σ), Dense(5, 2), softmax]

model(x) = foldl((x, m) -> m(x), x, layers)
model(x) = foldl((x, m) -> m(x), layers, init = x)

model(rand(10)) # => 2-element vector
```
@@ -6,6 +6,8 @@ These core layers form the foundation of almost all neural networks.
Chain
Dense
Conv
MaxPool
MeanPool
```

## Recurrent Layers
@@ -1,7 +1,7 @@
# Regularisation

Applying regularisation to model parameters is straightforward. We just need to
apply an appropriate regulariser, such as `vecnorm`, to each model parameter and
apply an appropriate regulariser, such as `norm`, to each model parameter and
add the result to the overall loss.

For example, say we have a simple regression.
@@ -15,12 +15,12 @@ loss(x, y) = crossentropy(softmax(m(x)), y)
We can regularise this by taking the (L2) norm of the parameters, `m.W` and `m.b`.

```julia
penalty() = vecnorm(m.W) + vecnorm(m.b)
penalty() = norm(m.W) + norm(m.b)
loss(x, y) = crossentropy(softmax(m(x)), y) + penalty()
```

When working with layers, Flux provides the `params` function to grab all
parameters at once. We can easily penalise everything with `sum(vecnorm, params)`.
parameters at once. We can easily penalise everything with `sum(norm, params)`.

```julia
julia> params(m)
@@ -28,7 +28,7 @@ julia> params(m)
param([0.355408 0.533092; … 0.430459 0.171498])
param([0.0, 0.0, 0.0, 0.0, 0.0])

julia> sum(vecnorm, params(m))
julia> sum(norm, params(m))
26.01749952921026 (tracked)
```
@@ -40,7 +40,7 @@ m = Chain(
Dense(128, 32, relu),
Dense(32, 10), softmax)

loss(x, y) = crossentropy(m(x), y) + sum(vecnorm, params(m))
loss(x, y) = crossentropy(m(x), y) + sum(norm, params(m))

loss(rand(28^2), rand(10))
```
@@ -57,6 +57,6 @@ julia> activations(c, rand(10))
param([0.0330606, -0.456104])
param([0.61991, 0.38009])

julia> sum(vecnorm, ans)
julia> sum(norm, ans)
2.639678767773633 (tracked)
```
@@ -1,5 +1,3 @@
__precompile__()

module Flux

# Zero Flux Given
@@ -7,12 +5,11 @@ module Flux
using MacroTools, Juno, Requires, Reexport, Statistics, Random
using MacroTools: @forward

export Chain, Dense, RNN, LSTM, GRU, Conv,
export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool,
Dropout, LayerNorm, BatchNorm,
params, mapleaves, cpu, gpu

@reexport using NNlib
using NNlib: @fix

include("tracker/Tracker.jl")
using .Tracker
@@ -1,6 +1,6 @@
module CUDA

using CuArrays
using ..CuArrays

if CuArrays.cudnn_available()
include("curnn.jl")
@@ -2,6 +2,8 @@ using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
cudnnBatchNormMode_t, cudnnHandle_t, libcudnn_handle, cudnnDataType, TensorDesc, FilterDesc
import ..Flux: data

using LinearAlgebra

mutable struct DropoutDesc
ptr::Ptr{Nothing}
states::CuVector{UInt8}
@@ -18,8 +20,9 @@ function DropoutDesc(ρ::Real; seed::Integer=0)
desc = DropoutDesc(d[], states)
@check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,Ptr{Nothing},Csize_t,Culonglong),
desc,libcudnn_handle[],ρ,states,length(states),seed)
finalizer(desc, x ->
@check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x))
finalizer(desc) do x
@check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
end
return desc
end
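A note for readers following the 0.7/1.0 migration in this hunk: `finalizer(obj, f)` became `finalizer(f, obj)`, which the `do`-block form above relies on. A minimal, self-contained sketch (the `Handle` type is purely illustrative, not part of Flux):

```julia
mutable struct Handle      # hypothetical resource wrapper
    ptr::Int
end

h = Handle(42)

# `finalizer(h) do x ... end` is sugar for finalizer(f, h) on Julia 1.0.
finalizer(h) do x
    # finalizers should not block or yield; hand real cleanup off asynchronously
    @async println("releasing handle ", x.ptr)
end
```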
@@ -11,6 +11,8 @@ function __init__()
end

include("mnist.jl")
export MNIST

include("cmudict.jl")
using .CMUDict
@@ -14,7 +14,7 @@ function load()
return
end
end
info("Downloading CMUDict dataset")
@info "Downloading CMUDict dataset"
mkpath(deps("cmudict"))
for x in suffixes
download("$cache_prefix/http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x",
@@ -1,11 +1,17 @@
module MNIST

using GZip, Colors
using CodecZlib, Colors

const Gray = Colors.Gray{Colors.N0f8}

const dir = joinpath(@__DIR__, "../../deps/mnist")

function gzopen(f, file)
open(file) do io
f(GzipDecompressorStream(io))
end
end

function load()
mkpath(dir)
cd(dir) do
@@ -14,10 +20,10 @@ function load()
"t10k-images-idx3-ubyte",
"t10k-labels-idx1-ubyte"]
isfile(file) && continue
info("Downloading MNIST dataset")
@info "Downloading MNIST dataset"
download("https://cache.julialang.org/http://yann.lecun.com/exdb/mnist/$file.gz", "$file.gz")
open(file, "w") do io
write(io, GZip.open(read, "$file.gz"))
write(io, gzopen(read, "$file.gz"))
end
end
end
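The loader above swaps GZip.jl for CodecZlib's `GzipDecompressorStream` via the `gzopen` helper; a minimal sketch of that pattern (the file name is illustrative):

```julia
using CodecZlib

# Read a gzip-compressed file in full, mirroring the `gzopen` helper above.
gzread(path) = open(io -> read(GzipDecompressorStream(io)), path)

bytes = gzread("train-images-idx3-ubyte.gz")  # Vector{UInt8}
```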
@@ -49,7 +55,7 @@ function labelheader(io::IO)
end

function rawimage(io::IO)
img = Array{Gray}(NCOLS, NROWS)
img = Array{Gray}(undef, NCOLS, NROWS)
for i in 1:NCOLS, j in 1:NROWS
img[i, j] = reinterpret(Colors.N0f8, read(io, UInt8))
end
@@ -5,7 +5,7 @@ using ..Data: deps

function load()
isfile(deps("sentiment.zip")) || return
info("Downloading sentiment treebank dataset")
@info "Downloading sentiment treebank dataset"
download("https://cache.julialang.org/https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip",
deps("sentiment.zip"))
end
@@ -14,7 +14,7 @@ getfile(r, name) = r.files[findfirst(x -> x.name == name, r.files)]

function getfile(name)
r = ZipFile.Reader(deps("sentiment.zip"))
text = readstring(getfile(r, "trees/$name"))
text = read(getfile(r, "trees/$name"), String)
close(r)
return text
end
@@ -29,12 +29,12 @@ function parsetree(s)
s = replace(s, r"\$", s -> "\\\$")
s = replace(s, r"[^\s\(\)]+", s -> "\"$s\"")
s = replace(s, " ", ", ")
return totree(parse(s))
return totree(Meta.parse(s))
end

function gettrees(name)
load()
ss = split(getfile("$name.txt"), '\n', keep = false)
ss = split(getfile("$name.txt"), '\n', keepempty = false)
return parsetree.(ss)
end
@@ -21,8 +21,8 @@ struct Chain
Chain(xs...) = new([xs...])
end

@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
@forward Chain.layers Base.start, Base.next, Base.done
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!
@forward Chain.layers Base.iterate

children(c::Chain) = c.layers
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
@@ -38,7 +38,7 @@ function Base.show(io::IO, c::Chain)
print(io, ")")
end

activations(c::Chain, x) = accumulate((x, m) -> m(x), x, c.layers)
activations(c::Chain, x) = accumulate((x, m) -> m(x), c.layers, init = x)

"""
Dense(in::Integer, out::Integer, σ = identity)
@@ -77,7 +77,7 @@ end

function (a::Dense)(x)
W, b, σ = a.W, a.b, a.σ
@fix σ.(W*x .+ b)
σ.(W*x .+ b)
end

function Base.show(io::IO, l::Dense)
@@ -1,6 +1,6 @@
using NNlib: conv

@generated sub2(::Type{Val{N}}) where N = :(Val{$(N-2)})
@generated sub2(::Val{N}) where N = :(Val($(N-2)))

expand(N, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N)
@@ -28,11 +28,11 @@ end

Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N} =
Conv(σ, w, b, expand.(sub2(Val{N}), (stride, pad, dilation))...)
Conv(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)

Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
stride = 1, pad = 0, dilation = 1) where N =
Conv(param(init(k..., ch...)), param(zero(ch[2])), σ,
Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
stride = stride, pad = pad, dilation = dilation)

@treelike Conv
@@ -50,3 +50,48 @@ function Base.show(io::IO, l::Conv)
l.σ == identity || print(io, ", ", l.σ)
print(io, ")")
end


"""
MaxPool(k)

Max pooling layer. `k` stands for the size of the window for each dimension of the input.

Takes the keyword arguments `pad` and `stride`.
"""
struct MaxPool{N}
k::NTuple{N,Int}
pad::NTuple{N,Int}
stride::NTuple{N,Int}
end

MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
MaxPool(k, expand(Val(N), pad), expand(Val(N), stride))

(m::MaxPool)(x) = maxpool(x, m.k; pad = m.pad, stride = m.stride)

function Base.show(io::IO, m::MaxPool)
print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
end

"""
MeanPool(k)

Mean pooling layer. `k` stands for the size of the window for each dimension of the input.

Takes the keyword arguments `pad` and `stride`.
"""
struct MeanPool{N}
k::NTuple{N,Int}
pad::NTuple{N,Int}
stride::NTuple{N,Int}
end

MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
MeanPool(k, expand(Val(N), pad), expand(Val(N), stride))

(m::MeanPool)(x) = meanpool(x, m.k; pad = m.pad, stride = m.stride)

function Base.show(io::IO, m::MeanPool)
print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
end
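A quick usage sketch of the `MaxPool`/`MeanPool` layers added in this hunk (input sizes are illustrative):

```julia
using Flux

x = randn(28, 28, 1, 8)              # WHCN: 28×28 single-channel images, batch of 8
m = Chain(Conv((3, 3), 1=>4, relu),  # -> 26×26×4
          MaxPool((2, 2)),           # -> 13×13×4
          MeanPool((2, 2)))          # -> 6×6×4

size(m(x))                           # (6, 6, 4, 8)
```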
@@ -141,8 +141,8 @@ function (BN::BatchNorm)(x)

# update moving mean/std
mtm = data(convert(T, BN.momentum))
BN.μ = ((1 - mtm) .* BN.μ .+ mtm .* squeeze(data(μ), dims = (axes...)))
BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* squeeze(data(σ²), dims = (axes...)) .* m ./ (m - 1))
BN.μ = ((1 - mtm) .* BN.μ .+ mtm .* dropdims(data(μ), dims = (axes...)))
BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* dropdims(data(σ²), dims = (axes...)) .* m ./ (m - 1))
end

ϵ = convert(T, BN.ϵ)
@@ -1,4 +1,4 @@
gate(h, n) = (1:h) + h*(n-1)
gate(h, n) = (1:h) .+ h*(n-1)
gate(x::AbstractVector, h, n) = x[gate(h,n)]
gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]
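The `gate` change is a Julia 1.0 broadcasting fix: adding a scalar offset to a range now needs the dotted operator. A tiny sketch of what the helper returns:

```julia
# Offsetting a UnitRange requires broadcasting on Julia 1.0; the result is still a range.
gate(h, n) = (1:h) .+ h*(n-1)

gate(3, 1)  # 1:3 — slice for the first gate of width 3
gate(3, 2)  # 4:6 — slice for the second gate
```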
@@ -84,7 +84,7 @@ end
RNNCell(in::Integer, out::Integer, σ = tanh;
init = glorot_uniform) =
RNNCell(σ, param(init(out, in)), param(init(out, out)),
param(zeros(out)), param(initn(out)))
param(zeros(out)), param(init(out)))

function (m::RNNCell)(h, x)
σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
@@ -122,14 +122,13 @@ end

function LSTMCell(in::Integer, out::Integer;
init = glorot_uniform)
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zero(out*4)),
param(initn(out)), param(initn(out)))
cell.b.data[gate(out, 2)] = 1
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)),
param(init(out)), param(init(out)))
cell.b.data[gate(out, 2)] .= 1
return cell
end

function (m::LSTMCell)(h_, x)
h, c = h_ # TODO: nicer syntax on 0.7
function (m::LSTMCell)((h, c), x)
b, o = m.b, size(h, 1)
g = m.Wi*x .+ m.Wh*h .+ b
input = σ.(gate(g, o, 1))
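The new `(m::LSTMCell)((h, c), x)` method leans on Julia 1.0's argument destructuring, which replaces the old `h, c = h_` line; a minimal illustration:

```julia
# The hidden state arrives as a tuple and is unpacked directly in the signature.
step((h, c), x) = (h + x, c)

step((1.0, 0.0), 2.0)  # (3.0, 0.0)
```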
@@ -170,7 +169,7 @@ end

GRUCell(in, out; init = glorot_uniform) =
GRUCell(param(init(out*3, in)), param(init(out*3, out)),
param(zero(out*3)), param(initn(out)))
param(zeros(out*3)), param(init(out)))

function (m::GRUCell)(h, x)
b, o = m.b, size(h, 1)
@@ -5,7 +5,7 @@ using NNlib: logsoftmax, logσ
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)

function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
@fix -sum(y .* log.(ŷ) .* weight) / size(y, 2)
-sum(y .* log.(ŷ) .* weight) / size(y, 2)
end

@deprecate logloss(x, y) crossentropy(x, y)
@@ -47,7 +47,7 @@ logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
Normalise each column of `x` to mean 0 and standard deviation 1.
"""
function normalise(x::AbstractVecOrMat)
μ′ = mean(x, 1)
σ′ = std(x, 1, mean = μ′)
μ′ = mean(x, dims = 1)
σ′ = std(x, dims = 1, mean = μ′)
return (x .- μ′) ./ σ′
end
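`mean` and `std` take a `dims` keyword on Julia 1.0, which is what the rewritten `normalise` uses; a quick check of the column-wise behaviour (sizes are illustrative):

```julia
using Statistics

x = randn(5, 3) .* 4 .+ 2
x̂ = (x .- mean(x, dims = 1)) ./ std(x, dims = 1)

mean(x̂, dims = 1)  # each column ≈ 0
std(x̂, dims = 1)   # each column ≈ 1
```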
@@ -33,8 +33,9 @@ import Adapt.adapt
adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))

@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
import CuArrays: CuArray, cudaconvert
Base.Broadcast._containertype(::Type{<:OneHotMatrix{<:CuArray}}) = CuArray
import .CuArrays: CuArray, cudaconvert
import Base.Broadcast: BroadcastStyle, ArrayStyle
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
end
@@ -53,11 +54,15 @@ end
onehotbatch(ls, labels, unk...) =
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])

argmax(y::AbstractVector, labels = 1:length(y)) =
labels[findfirst(y, maximum(y))]
onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]

argmax(y::AbstractMatrix, l...) =
squeeze(mapslices(y -> argmax(y, l...), y, 1), 1)
onecold(y::AbstractMatrix, labels...) =
dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1)

function argmax(xs...)
Base.depwarn("`argmax(...)` is deprecated, use `onecold(...)` instead.", :argmax)
return onecold(xs...)
end

# Ambiguity hack
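A short sketch of the `onecold` behaviour defined above, including the matrix (per-column) case handled through `mapslices`/`dropdims`:

```julia
using Flux: onecold

onecold([0.1, 0.7, 0.2], [:a, :b, :c])   # :b — index of the maximum, mapped to a label

probs = [0.1  0.9;
         0.7  0.05;
         0.2  0.05]
onecold(probs, [:a, :b, :c])             # [:b, :a] — one label per column
```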
@@ -2,14 +2,14 @@ module Optimise

export train!,
SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException

struct Param{T}
x::T
Δ::T
end

Base.convert(::Type{Param}, x::AbstractArray) = Param(x, zero(x))
Param(x::AbstractArray) = Param(x, zero(x))

include("optimisers.jl")
include("interface.jl")
@@ -17,6 +17,7 @@ include("train.jl")

using Flux.Tracker: TrackedArray

Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)
Param(x::TrackedArray) = Param(x.data, x.grad)
# Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)

end
@@ -1,5 +1,6 @@
using Juno
using Flux.Tracker: back!
import Base.depwarn

runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs)
@@ -14,6 +15,25 @@ macro interrupts(ex)
end)
end

struct StopException <: Exception end
"""
stop()

Call `Flux.stop()` in a callback to indicate when a callback condition is met.
This would trigger the train loop to stop and exit.

```julia
# Example callback:

cb = function ()
accuracy() > 0.9 && Flux.stop()
end
```
"""
function stop()
throw(StopException())
end

"""
train!(loss, data, opt)
|
|||
cb = runall(cb)
|
||||
opt = runall(opt)
|
||||
@progress for d in data
|
||||
l = loss(d...)
|
||||
@interrupts back!(l)
|
||||
opt()
|
||||
cb() == :stop && break
|
||||
try
|
||||
l = loss(d...)
|
||||
@interrupts back!(l)
|
||||
opt()
|
||||
if cb() == :stop
|
||||
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
|
||||
break
|
||||
end
|
||||
catch ex
|
||||
if ex isa StopException
|
||||
break
|
||||
else
|
||||
rethrow(ex)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
@@ -59,7 +90,7 @@ hello
"""
macro epochs(n, ex)
:(@progress for i = 1:$(esc(n))
info("Epoch $i")
@info "Epoch $i"
$(esc(ex))
end)
end
@@ -77,8 +77,7 @@ include("numeric.jl")

Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
the sign of the gradient applied to `x`.
"""
the sign of the gradient applied to `x`."""
hook(f, x) = istracked(x) ? track(hook, f, x) : x
@grad hook(f, x) = data(x), Δ -> (nothing, f(Δ))
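A minimal sketch of using `hook` from the docstring above, with the tracker API as it stands at this point in the diff (the exact call pattern here is an assumption):

```julia
using Flux
using Flux.Tracker: hook, back!, grad

x = param([1.0, 2.0])
y = sum(hook(Δ -> -Δ, x))  # forward value is unchanged (3.0)
back!(y)
grad(x)                    # ≈ [-1.0, -1.0]: the incoming gradient was negated
```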
@@ -1,4 +1,4 @@
import Base: *, ==
import Base: *

import LinearAlgebra
using Statistics
@@ -48,7 +48,7 @@ back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back

# Fallthrough methods

for f in :[Base.size, Base.ndims].args
for f in :[Base.size, Base.ndims, Base.collect].args
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
end
@@ -60,9 +60,11 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =

Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)

x::TrackedArray == y = data(x) == y
y == x::TrackedArray = y == data(x)
x::TrackedArray == y::TrackedArray = data(x) == data(y)
for op in [:(==), :≈]
@eval Base.$op(x::TrackedArray, y::AbstractArray) = Base.$op(data(x), y)
@eval Base.$op(x::AbstractArray, y::TrackedArray) = Base.$op(x, data(y))
@eval Base.$op(x::TrackedArray, y::TrackedArray) = Base.$op(data(x), data(y))
end

# Array Stdlib
@@ -86,9 +88,9 @@ Base.adjoint(xs::TrackedArray) = track(adjoint, xs)
@grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),)
@grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)

Base.repeat(A::TrackedArray; kw...) = track(repeat, A; kw...)
Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)

@grad function repeat(xs; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A)))
@grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs)))
repeat(data(xs), inner = inner, outer = outer), function (Δ)
Δ′ = zero(xs)
S = size(xs)
@@ -286,15 +288,6 @@ x::TrackedVector * y::TrackedVector = track(*, x, y)
@grad a::AbstractMatrix * b::AbstractVecOrMat =
data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ)

# @grad function (a::AbstractMatrix * b::AbstractVecOrMat)
# # @show size(a) size(b)
# data(a)*data(b), function (Δ)
# @show size(Δ) size(b) size(Δ*transpose(b)) size(Δ*transpose(data(b)))
# @show typeof(Δ) typeof(b)
# (Δ * transpose(b), transpose(a) * Δ)
# end
# end

# NNlib

using NNlib
@@ -336,48 +329,85 @@ end

using ForwardDiff: Dual, partials, value

_size(x::AbstractArray) = size(x)
_size(x) = ()
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))

dualify(xs, n) = xs
dualify(xs::AbstractArray, ps) = map(x -> Dual(x, ps), xs)
dualify(xs::Real, ps) = Dual(xs, ps)
unbroadcast(x::AbstractArray, Δ) =
size(x) == size(Δ) ? Δ :
length(x) == length(Δ) ? trim(x, Δ) :
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))

unbroadcast(x::Tuple, Δ) =
x == size(Δ) ? Δ :
reshape(sum(Δ, dims = filter(n -> n > length(x) || x[n] == 1, 1:ndims(Δ))), x)
unbroadcast(x::Number, Δ) = sum(Δ)
unbroadcast(x::Base.RefValue{<:Function}, _) = nothing
unbroadcast(x::Base.RefValue{<:Val}, _) = nothing

unbroadcast(x::Tuple{}, Δ) = sum(Δ)
dual(x, p) = x
dual(x::Real, p) = Dual(x, p)

function getpartial(Δ, x, i)
@inbounds p = getindex(partials(x), i)
return Δ * p
function partial(f::F, Δ, i, args::Vararg{Any,N}) where {F,N}
dargs = ntuple(j -> dual(args[j], i==j), Val(N))
return Δ * f(dargs...).partials[1]
end

function ∇broadcast(f, args::Vararg{Any,N}) where N
sizes = _size.(args)
dargs = map((x,i) -> dualify(data(x), ntuple(j -> i==j, Val(N))), args, ntuple(identity, Val(N)))
out = broadcast(f, dargs...)
eltype(out) <: Dual || return out
y = value.(out)
back = function (Δ_)
Δ = data(Δ_)
Δargs = ntuple(i -> getpartial.(Δ, out, i), Val(N))
dxs = map((x, Δ) -> unbroadcast(x, Δ), sizes, Δargs)
nobacksies(:broadcast, dxs)
@inline function ∇broadcast(f::F, args::Vararg{Any,N}) where {F,N}
y = broadcast(f, data.(args)...)
eltype(y) <: Real || return y
eltype(y) == Bool && return y
function back(Δ)
Δargs = ntuple(i -> partial.(f, data(Δ), i, args...), Val(N))
dxs = unbroadcast.(args, Δargs)
return nobacksies(:broadcast, dxs)
end
# So we can return non-tracked arrays
track(Call(back, tracker.(args)), y)
end

using Base.Broadcast: BroadcastStyle
using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted

struct TrackedStyle <: BroadcastStyle end

Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray,TrackedReal}}) = TrackedStyle()
Broadcast.BroadcastStyle(::TrackedStyle, ::BroadcastStyle) = TrackedStyle()

function Base.copy(bc::Broadcast.Broadcasted{TrackedStyle})
bc = Broadcast.flatten(bc)
∇broadcast(bc.f, bc.args...)
# We have to re-build the original broadcast struct to get the appropriate array
# style. We need this primarily to support CuArrays' broadcasting fixes.
broadcast_rebuild(xs) = data(xs)

broadcast_rebuild(bc::Broadcasted) =
broadcasted(bc.f, broadcast_rebuild.(bc.args)...)

preprocess(x) = x

function Base.Broadcast.materialize(bc::Broadcasted{TrackedStyle})
bc1 = Broadcast.flatten(bc)
bc2 = Broadcast.flatten(broadcast_rebuild(bc))
∇broadcast(bc2.f, bc1.args...)
end

using Requires

# https://github.com/FluxML/Flux.jl/issues/353
@init Requires.isprecompiling() || @eval Base.Broadcast begin
function flatten(bc::Broadcasted{Style}) where {Style}
isflat(bc) && return bc
args = cat_nested(bc)
let makeargs = make_makeargs(bc), f = bc.f
newf = @inline function(args::Vararg{Any,N}) where N
f(makeargs(args...)...)
end
return Broadcasted{Style}(newf, args, bc.axes)
end
end
@inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}})
bc = t[1]
let makeargs = make_makeargs(makeargs, tail(t)), f = bc.f
let makeargs = make_makeargs(makeargs, bc.args)
headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
return @inline function(args::Vararg{Any,N}) where N
args1 = makeargs(args...)
a, b = headargs(args1...), tailargs(args1...)
(f(a...), b...)
end
end
end
end
end
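The core idea behind the new `unbroadcast` above is to sum a gradient back down to the shape of each input that broadcasting expanded. A standalone sketch of that reduction, independent of the tracker types (names here are illustrative):

```julia
# Sum Δ over the dimensions along which `x` was (virtually) expanded, so the
# gradient ends up with the same shape as `x` again.
function unbroadcast_demo(x::AbstractArray, Δ::AbstractArray)
    size(x) == size(Δ) && return Δ
    dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ) + 1, ndims(Δ))
    reshape(sum(Δ, dims = dims), size(x))
end

x = ones(3, 1)           # broadcast against a 3×4 array in the forward pass
Δ = ones(3, 4)           # upstream gradient
unbroadcast_demo(x, Δ)   # 3×1 array of 4.0 — the expanded column dimension is summed away
```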
@@ -70,7 +70,7 @@ struct Params
Params(xs) = new(IdSet(xs))
end

@forward Params.params Base.start, Base.next, Base.done
@forward Params.params Base.iterate, Base.length

function Base.show(io::IO, ps::Params)
print(io, "Params([")
@@ -86,6 +86,8 @@ Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")

Grads() = Grads(IdDict())

@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate

Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))

Base.getindex(g::Grads, x::Tracked) = g.grads[x]
@@ -94,7 +96,6 @@ function Base.getindex(g::Grads, x)
g[tracker(x)]
end

@forward Grads.grads Base.setindex!, Base.haskey

accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
@@ -136,7 +137,7 @@ end
function forward(f, args...)
args = param.(args)
y, back = forward(() -> f(args...), Params(args))
y, Δ -> getindex.(back(Δ), args)
y, Δ -> getindex.(Ref(back(Δ)), args)
end

function losscheck(x)
@@ -11,7 +11,7 @@ Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
Base.in(x, s::IdSet) = haskey(s.dict, x)

(::Type{IdSet{T}})(xs) where T = push!(IdSet{T}(), xs...)
IdSet{T}(xs) where T = push!(IdSet{T}(), xs...)

IdSet(xs) = IdSet{eltype(xs)}(xs)
@@ -20,6 +20,8 @@ Base.similar(s::IdSet, T::Type) = IdSet{T}()

@forward IdSet.dict Base.length

Base.start(s::IdSet) = start(keys(s.dict))
Base.next(s::IdSet, st) = next(keys(s.dict), st)
Base.done(s::IdSet, st) = done(keys(s.dict), st)
function Base.iterate(v::IdSet, state...)
y = Base.iterate(keys(v.dict), state...)
y === nothing && return nothing
return (y[1], y[2])
end
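`start`/`next`/`done` were folded into a single `iterate` function in Julia 1.0, which is what the `IdSet` method above implements; the generic pattern looks like this (the `Countdown` type is illustrative, not part of Flux):

```julia
struct Countdown           # toy container
    from::Int
end

Base.length(c::Countdown) = c.from

# One function drives the protocol: return (item, nextstate) or `nothing` when done.
function Base.iterate(c::Countdown, state = c.from)
    state < 1 && return nothing
    return (state, state - 1)
end

collect(Countdown(3))  # 3-element vector: 3, 2, 1
```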
@@ -30,8 +30,11 @@ Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
error("Not implemented: convert tracked $S to tracked $T")

Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y)
Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y)
for op in [:(==), :≈, :<]
@eval Base.$op(x::TrackedReal, y::Real) = Base.$op(data(x), y)
@eval Base.$op(x::Real, y::TrackedReal) = Base.$op(x, data(y))
@eval Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y))
end

Base.eps(x::TrackedReal) = eps(data(x))
@@ -1,7 +1,7 @@
using Flux, Flux.Tracker, CuArrays, Test
using Flux: gpu

info("Testing Flux/GPU")
@info "Testing GPU Support"

@testset "CuArrays" begin
@@ -14,6 +14,7 @@ cx = gpu(x)
x = Flux.onehotbatch([1, 2, 3], 1:3)
cx = gpu(x)
@test cx isa Flux.OneHotMatrix && cx.data isa CuArray
@test (cx .+ 1) isa CuArray

m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
cm = gpu(m)
@@ -25,10 +26,13 @@ x = [1,2,3]
cx = gpu(x)
@test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)

# Fails in Pkg.test ffs
# c = gpu(Conv((2,2),3=>4))
# l = c(gpu(rand(10,10,3,2)))
# Flux.back!(sum(l))
xs = param(rand(5,5))
ys = Flux.onehotbatch(1:5,1:5)
@test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)

c = gpu(Conv((2,2),3=>4))
l = c(gpu(rand(10,10,3,2)))
Flux.back!(sum(l))

end
@@ -1,6 +1,8 @@
using Flux, Flux.Tracker, CuArrays, Test
using Flux.Tracker: TrackedArray, data

@info "Testing Flux CUDNN"

@testset "CUDNN BatchNorm" begin
x = TrackedArray(rand(10, 10, 3, 1))
m = BatchNorm(3)
|
@ -6,3 +6,6 @@ using Test
|
|||
@test length(CMUDict.phones()) == 39
|
||||
|
||||
@test length(CMUDict.symbols()) == 84
|
||||
|
||||
@test MNIST.images()[1] isa Matrix
|
||||
@test MNIST.labels() isa Vector{Int64}
|
||||
|
|
|
@@ -0,0 +1,23 @@
using Flux, Test
using Flux: maxpool, meanpool

@testset "Pooling" begin
x = randn(10, 10, 3, 2)
mp = MaxPool((2, 2))
@test mp(x) == maxpool(x, (2,2))
mp = MeanPool((2, 2))
@test mp(x) == meanpool(x, (2,2))
end

@testset "CNN" begin
r = zeros(28, 28, 1, 5)
m = Chain(
Conv((2, 2), 1=>16, relu),
MaxPool((2,2)),
Conv((2, 2), 16=>8, relu),
MaxPool((2,2)),
x -> reshape(x, :, size(x, 4)),
Dense(288, 10), softmax)

@test size(m(r)) == (10, 5)
end
@@ -55,17 +55,17 @@ end
# .1 * 4 + 0 = .4
@test m.μ ≈ reshape([0.3, 0.4], 2, 1)

# julia> .1 .* std(x, 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
# julia> .1 .* std(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
# 2×1 Array{Float64,2}:
#  1.14495
#  1.14495
@test m.σ² ≈ 0.1 .* var(x.data, 2, corrected=false)*3/2 + 0.9 .* [1., 1.]
@test m.σ² ≈ .1 .* std(x.data, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]

testmode!(m)
@test !m.active

y = m(x).data
@test isapprox(y, data((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-6)
x′ = m(x).data
@test x′[1] ≈ (1 .- 0.3) / 1.1449489742783179
end

# with activation function
@@ -42,8 +42,8 @@ const ϵ = 1e-7

logŷ, y = randn(3), rand(3)
@testset "binarycrossentropy" begin
@test binarycrossentropy.(σ.(logŷ), y; ϵ=0) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
@test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 - y).*log.(1 - σ.(logŷ) .+ eps.(σ.(logŷ)))
@test binarycrossentropy.(σ.(logŷ), y; ϵ=0) ≈ -y.*log.(σ.(logŷ)) - (1 .- y).*log.(1 .- σ.(logŷ))
@test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 .- y).*log.(1 .- σ.(logŷ) .+ eps.(σ.(logŷ)))
end

@testset "logitbinarycrossentropy" begin
@@ -0,0 +1,13 @@
using Flux:onecold
using Test

@testset "onecold" begin
a = [1, 2, 5, 3.]
A = [1 20 5; 2 7 6; 3 9 10; 2 1 14]
labels = ['A', 'B', 'C', 'D']

@test onecold(a) == 3
@test onecold(A) == [3, 1, 4]
@test onecold(a, labels) == 'C'
@test onecold(A, labels) == ['C', 'A', 'D']
end
@@ -1,6 +1,6 @@
using Flux.Optimise
using Flux.Tracker

using Test
@testset "Optimise" begin
w = randn(10, 10)
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
@@ -23,7 +23,7 @@ end
Flux.train!(() -> (sleep(0.1); i += 1; l),
Iterators.repeated((), 100),
()->(),
cb = Flux.throttle(() -> (i > 3 && :stop), 1))
cb = Flux.throttle(() -> (i > 3 && stop()), 1))

@test 3 < i < 50
end
@@ -1,18 +1,47 @@
using Flux, Test, Random
# Pkg.test runs with --check_bounds=1, forcing all bounds checks.
# This is incompatible with CUDAnative (see JuliaGPU/CUDAnative.jl#98)
if Base.JLOptions().check_bounds == 1
file = @__FILE__
run(```
$(Base.julia_cmd())
--color=$(Base.have_color ? "yes" : "no")
--compiled-modules=$(Bool(Base.JLOptions().use_compiled_modules) ? "yes" : "no")
--startup-file=$(Base.JLOptions().startupfile != 2 ? "yes" : "no")
--code-coverage=$(["none", "user", "all"][1+Base.JLOptions().code_coverage])
$(file)
```)
exit()
end

srand(0)
using Flux, Test, Random
using Random

Random.seed!(0)

# So we can use the system CuArrays
insert!(LOAD_PATH, 2, "@v#.#")

@testset "Flux" begin

@info "Testing Basics"

include("utils.jl")
include("tracker.jl")
include("layers/normalisation.jl")
include("layers/stateless.jl")
include("onehot.jl")
include("optimise.jl")
include("data.jl")

# if Base.find_in_path("CuArrays") ≠ nothing
# include("cuda/cuda.jl")
# end
@info "Testing Layers"

include("layers/normalisation.jl")
include("layers/stateless.jl")
include("layers/conv.jl")

@info "Running Gradient Checks"

include("tracker.jl")

if Base.find_package("CuArrays") != nothing
include("cuda/cuda.jl")
end

end
@@ -3,23 +3,20 @@ using Flux.Tracker, Test, NNlib
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
using NNlib: conv
using Printf: @sprintf
using LinearAlgebra: diagm, dot, LowerTriangular, norm
using LinearAlgebra: Diagonal, dot, LowerTriangular, norm
using Statistics: mean, std
using Random
# using StatsBase

gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
gradtest(f, dims...) = gradtest(f, rand.(dims)...)

gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
@testset "Tracker" begin

@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)

@test gradtest((w, x) -> w'*x, randn(Float64,10, 2), randn(Float64,10))
@test gradtest((w, x) -> w*x', randn(Float64,5,5), randn(Float64,5,5))

@test gradtest(x -> sum(x, dims = (2, 3)), (3,4,5))
@test gradtest(x -> sum(x, dims = 1), randn(Float64,2,3))
@test gradtest(x -> sum(x, dims = [1,2]), randn(Float64,2,3))
@@ -36,7 +33,6 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))

@test gradtest(x -> x', rand(5))

function promotiontest(f, A, B, C)
r0 = f(A, B, C)
r1 = f(param(A), B, C)
@@ -69,6 +65,7 @@ end
@test gradtest(vcatf, rand(5)', rand(2,5))
end


@testset for hcatf in [hcat, cat2]
@test gradtest(hcatf, rand(5), rand(5))
@test gradtest(hcatf, rand(5)', rand(5)')
@@ -97,7 +94,7 @@ end

@test !isa(vcat(rand(2)), TrackedArray)
@test !isa(hcat(rand(2)), TrackedArray)
@test !isa(cat(1,rand(2)), TrackedArray)
@test !isa(cat(rand(2), dims=1), TrackedArray)

@test gradtest((a,b)->cat(a, b, dims = (2,3,5)), rand(2,3), rand(2,4,2,1))
@@ -115,10 +112,12 @@ end
promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5))
promotiontest((x...) -> cat(x..., dims = 3), rand(4,5,3), rand(4,5,1), rand(4,5,2))
end

end

@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))

@test gradtest(x -> repeat(x; inner=2), rand(5))
@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
@test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))
@@ -128,7 +127,7 @@ end
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))

@test gradtest(diagm, rand(3))
@test gradtest(f-> Matrix(Diagonal(f)), rand(3))

@testset "mean" begin
@test gradtest(mean, rand(2, 3))
@@ -183,9 +182,30 @@ end
@test gradtest(x -> meanpool(x, (2,2)), rand(10, 10, 3, 2))
@test gradtest(x -> meanpool(x, (2,2,2)), rand(5, 5, 5, 3, 2))

@test (param([1,2,3]) .< 2) == [true, false, false]
@testset "equality & order" begin
# TrackedReal
@test param(2)^2 == param(4)
@test param(2)^2 == 4
@test 4 == param(2)^2

@test param(2)^2 == 4.0
@test param(2)^2 ≈ param(4)
@test param(2)^2 ≈ 4
@test 4 ≈ param(2)^2

@test (param([1,2,3]) .< 2) == [true, false, false]
@test (param([1,2,3]) .<= 2) == [true, true, false]
@test (2 .> param([1,2,3])) == [true, false, false]
@test (2 .>= param([1,2,3])) == [true, true, false]

# TrackedArray
@test param([1,2,3]).^2 == param([1,4,9])
@test [1,2,3].^2 == param([1,4,9])
@test param([1,2,3]).^2 == [1,4,9]

@test param([1,2,3]).^2 ≈ param([1,4,9])
@test [1,2,3].^2 ≈ param([1,4,9])
@test param([1,2,3]).^2 ≈ [1,4,9]
end

@testset "reshape" begin
x = reshape(param(rand(2,2,2)), 4, 2)
@@ -1,11 +1,13 @@
using Flux: throttle, initn, glorot_uniform, glorot_normal, jacobian
using Flux
using Flux: throttle, jacobian, initn, glorot_uniform, glorot_normal
using StatsBase: std
using Dates
using Random
using Test

@testset "Throttle" begin
@testset "default behaviour" begin
a = []
f = throttle(()->push!(a, now()), 1, leading=true, trailing=false)
f = throttle(()->push!(a, time()), 1, leading=true, trailing=false)
f()
f()
f()
@@ -15,7 +17,7 @@ using Dates

@testset "leading behaviour" begin
a = []
f = throttle(()->push!(a, now()), 1, leading=true, trailing=false)
f = throttle(()->push!(a, time()), 1, leading=true, trailing=false)
f()
@test length(a) == 1
f()
@@ -27,7 +29,7 @@ using Dates

@testset "trailing behaviour" begin
a = []
f = throttle(()->push!(a, now()), 1, leading=false, trailing=true)
f = throttle(()->push!(a, time()), 1, leading=false, trailing=true)
f()
@test length(a) == 0
f()
@@ -61,7 +63,7 @@ end

@testset "Initialization" begin
# Set random seed so that these tests don't fail randomly
srand(0)
Random.seed!(0)
# initn() should yield a kernel with stddev ~= 1e-2
v = initn(10, 10)
@test std(v) > 0.9*1e-2