commit f3e39a1e55
Author: Avik Pal
Date:   2018-10-01 09:50:30 +05:30

45 changed files with 718 additions and 429 deletions

.gitignore
@@ -5,3 +5,4 @@ docs/build/
 docs/site/
 docs/flux.css
 deps
+Manifest.toml

.travis.yml
@@ -4,11 +4,16 @@ os:
 - linux
 # - osx
 julia:
-- 0.6
+- 0.7
+- 1.0
+- nightly
 # uncomment the following lines to override the default test script
-script:
-- if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
-- julia -e 'Pkg.clone(pwd()); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'
+# script:
+#   - if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
+#   - julia -e 'Pkg.clone(pwd()); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'
+matrix:
+  allow_failures:
+    - julia: nightly
 after_success:
-- julia -e 'Pkg.add("Documenter")'
-- julia -e 'cd(Pkg.dir("Flux")); include(joinpath("docs", "make.jl"))'
+- julia -e 'using Pkg; Pkg.add("Documenter"); Pkg.add("NNlib")'
+- julia -e 'using Pkg; cd(Pkg.dir("Flux")); include(joinpath("docs", "make.jl"))'

REQUIRE
@@ -1,14 +1,15 @@
-julia 0.6.0
+julia 0.7
 Juno
 MacroTools 0.3.3
 NNlib
 Requires
 Adapt
-GZip
+CodecZlib
 Colors
 ZipFile
 AbstractTrees
 Reexport
+StatsBase

 # AD
 ForwardDiff 0.5.0

docs/make.jl
@@ -26,6 +26,6 @@ deploydocs(
   repo = "github.com/FluxML/Flux.jl.git",
   target = "build",
   osname = "linux",
-  julia = "0.6",
+  julia = "1.0",
   deps = nothing,
   make = nothing)

docs/src/data/onehot.md
@@ -3,7 +3,7 @@
 It's common to encode categorical variables (like `true`, `false` or `cat`, `dog`) in "one-of-k" or ["one-hot"](https://en.wikipedia.org/wiki/One-hot) form. Flux provides the `onehot` function to make this easy.

 ```
-julia> using Flux: onehot
+julia> using Flux: onehot, onecold

 julia> onehot(:b, [:a, :b, :c])
 3-element Flux.OneHotVector:
@@ -18,22 +18,22 @@ julia> onehot(:c, [:a, :b, :c])
  true
 ```

-The inverse is `argmax` (which can take a general probability distribution, as well as just booleans).
+The inverse is `onecold` (which can take a general probability distribution, as well as just booleans).

 ```julia
-julia> argmax(ans, [:a, :b, :c])
+julia> onecold(ans, [:a, :b, :c])
 :c

-julia> argmax([true, false, false], [:a, :b, :c])
+julia> onecold([true, false, false], [:a, :b, :c])
 :a

-julia> argmax([0.3, 0.2, 0.5], [:a, :b, :c])
+julia> onecold([0.3, 0.2, 0.5], [:a, :b, :c])
 :c
 ```

 ## Batches

-`onehotbatch` creates a batch (matrix) of one-hot vectors, and `argmax` treats matrices as batches.
+`onehotbatch` creates a batch (matrix) of one-hot vectors, and `onecold` treats matrices as batches.

 ```julia
 julia> using Flux: onehotbatch
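A quick round trip with the renamed API (a sketch against the `onehot`/`onecold`/`onehotbatch` functions shown above; the `labels` vector is an assumption for illustration):

```julia
using Flux: onehotbatch, onecold

labels = [:a, :b, :c]                  # hypothetical label set
xs = onehotbatch([:b, :a, :c], labels) # 3×3 Flux.OneHotMatrix
onecold(xs, labels)                    # => [:b, :a, :c]
```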

docs/src/index.md
@@ -1,18 +1,17 @@
 # Flux: The Julia Machine Learning Library

-Flux is a library for machine learning. It comes "batteries-included" with many useful tools built in, but also lets you use the full power of the Julia language where you need it. The whole stack is implemented in clean Julia code (right down to the [GPU kernels](https://github.com/FluxML/CuArrays.jl)) and any part can be tweaked to your liking.
-
-# Installation
-
-Install [Julia 0.6.0 or later](https://julialang.org/downloads/), if you haven't already.
-
-```julia
-Pkg.add("Flux")
-# Optional but recommended
-Pkg.update() # Keep your packages up to date
-Pkg.test("Flux") # Check things installed correctly
-```
-
-Start with the [basics](models/basics.md). The [model zoo](https://github.com/FluxML/model-zoo/) is also a good starting point for many common kinds of models.
-
-See [GPU support](gpu.md) for more details on installing and using Flux with GPUs.
+Flux is a library for machine learning. It comes "batteries-included" with many useful tools built in, but also lets you use the full power of the Julia language where you need it. We follow a few key principles:
+
+* **Doing the obvious thing**. Flux has relatively few explicit APIs for features like regularisation or embeddings. Instead, writing down the mathematical form will work, and be fast.
+* **You could have written Flux**. All of it, from [LSTMs](https://github.com/FluxML/Flux.jl/blob/ec16a2c77dbf6ab8b92b0eecd11661be7a62feef/src/layers/recurrent.jl#L131) to [GPU kernels](https://github.com/JuliaGPU/CuArrays.jl), is straightforward Julia code. When in doubt, it's well worth looking at [the source](https://github.com/FluxML/Flux.jl/). If you need something different, you can easily roll your own.
+* **Play nicely with others**. Flux works well with Julia libraries from [data frames](https://github.com/JuliaComputing/JuliaDB.jl) and [images](https://github.com/JuliaImages/Images.jl) to [differential equation solvers](https://github.com/JuliaDiffEq/DifferentialEquations.jl), so you can easily build complex data processing pipelines that integrate Flux models.
+
+## Installation
+
+Download [Julia 1.0](https://julialang.org/) or later, if you haven't already. You can add Flux using Julia's package manager, by typing `] add Flux` in the Julia prompt.
+
+If you have CUDA you can also run `] add CuArrays` to get GPU support; see [here](gpu.md) for more details.
+
+## Learning Flux
+
+There are several different ways to learn Flux. If you just want to get started writing models, the [model zoo](https://github.com/FluxML/model-zoo/) gives good starting points for many common ones. This documentation provides a reference to all of Flux's APIs, as well as a from-scratch introduction to Flux's take on models and how they work. Once you understand these docs, congratulations, you also understand [Flux's source code](https://github.com/FluxML/Flux.jl), which is intended to be concise, legible and a good reference for more advanced concepts.

docs/src/internals/tracker.md
@@ -134,7 +134,7 @@ All `Tracked*` objects (`TrackedArray`, `TrackedReal`) are light wrappers around
 ```julia
 julia> x.tracker
-Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Void,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0])
+Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Nothing,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0])
 ```

 The `Tracker` stores the gradient of a given object, which we've seen before.

docs/src/models/basics.md
@@ -10,14 +10,14 @@ using Flux.Tracker
 f(x) = 3x^2 + 2x + 1

 # df/dx = 6x + 2
-f′(x) = Tracker.gradient(f, x)[1]
+df(x) = Tracker.gradient(f, x)[1]

-f′(2) # 14.0 (tracked)
+df(2) # 14.0 (tracked)

 # d²f/dx² = 6
-f′′(x) = Tracker.gradient(f′, x)[1]
+d2f(x) = Tracker.gradient(df, x)[1]

-f′′(2) # 6.0 (tracked)
+d2f(2) # 6.0 (tracked)
 ```

 (We'll learn more about why these numbers show up as `(tracked)` below.)
@@ -172,7 +172,7 @@ using Flux
 layers = [Dense(10, 5, σ), Dense(5, 2), softmax]

-model(x) = foldl((x, m) -> m(x), x, layers)
+model(x) = foldl((x, m) -> m(x), layers, init = x)

 model(rand(10)) # => 2-element vector
@@ -211,7 +211,7 @@ m(5) # => 26
 Flux provides a set of helpers for custom layers, which you can enable by calling

 ```julia
-Flux.treelike(Affine)
+Flux.@treelike Affine
 ```

 This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).
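For context, a minimal sketch of the `Affine` layer this page builds earlier, updated to the new macro form (the convenience constructor is illustrative, following the docs' running example, not an exported Flux API):

```julia
struct Affine
  W
  b
end

# Hypothetical constructor, as in the docs' running example.
Affine(in::Integer, out::Integer) =
  Affine(param(randn(out, in)), param(randn(out)))

(m::Affine)(x) = m.W * x .+ m.b

Flux.@treelike Affine  # now `params(m)`, `gpu(m)`, etc. work on Affine
```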

docs/src/models/layers.md
@@ -6,6 +6,8 @@ These core layers form the foundation of almost all neural networks.
 Chain
 Dense
 Conv
+MaxPool
+MeanPool
 ```

 ## Recurrent Layers
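A usage sketch of the newly documented pooling layers: a 2×2 `MaxPool` halves each spatial dimension of a WHCN input (the sizes here are illustrative):

```julia
using Flux

m = Chain(Conv((3, 3), 1=>8, relu), MaxPool((2, 2)))
x = rand(28, 28, 1, 16)  # width × height × channels × batch
size(m(x))               # => (13, 13, 8, 16)
```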

docs/src/models/regularisation.md
@@ -1,7 +1,7 @@
 # Regularisation

 Applying regularisation to model parameters is straightforward. We just need to
-apply an appropriate regulariser, such as `vecnorm`, to each model parameter and
+apply an appropriate regulariser, such as `norm`, to each model parameter and
 add the result to the overall loss.

 For example, say we have a simple regression.
@@ -15,12 +15,12 @@ loss(x, y) = crossentropy(softmax(m(x)), y)
 We can regularise this by taking the (L2) norm of the parameters, `m.W` and `m.b`.

 ```julia
-penalty() = vecnorm(m.W) + vecnorm(m.b)
+penalty() = norm(m.W) + norm(m.b)
 loss(x, y) = crossentropy(softmax(m(x)), y) + penalty()
 ```

 When working with layers, Flux provides the `params` function to grab all
-parameters at once. We can easily penalise everything with `sum(vecnorm, params)`.
+parameters at once. We can easily penalise everything with `sum(norm, params)`.

 ```julia
 julia> params(m)
@@ -28,7 +28,7 @@ julia> params(m)
 param([0.355408 0.533092; … 0.430459 0.171498])
 param([0.0, 0.0, 0.0, 0.0, 0.0])

-julia> sum(vecnorm, params(m))
+julia> sum(norm, params(m))
 26.01749952921026 (tracked)
 ```
@@ -40,7 +40,7 @@ m = Chain(
   Dense(128, 32, relu),
   Dense(32, 10), softmax)

-loss(x, y) = crossentropy(m(x), y) + sum(vecnorm, params(m))
+loss(x, y) = crossentropy(m(x), y) + sum(norm, params(m))

 loss(rand(28^2), rand(10))
@@ -57,6 +57,6 @@ julia> activations(c, rand(10))
 param([0.0330606, -0.456104])
 param([0.61991, 0.38009])

-julia> sum(vecnorm, ans)
+julia> sum(norm, ans)
 2.639678767773633 (tracked)
 ```
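Note that on Julia 0.7+ `norm` lives in the LinearAlgebra standard library, so the snippets above assume something like the following (a sketch; the `Dense(10, 5)` model is arbitrary):

```julia
using Flux, LinearAlgebra

m = Dense(10, 5)
penalty() = sum(norm, params(m))  # L2 penalty over every parameter array
```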

src/Flux.jl
@@ -1,18 +1,15 @@
-__precompile__()
-
 module Flux

 # Zero Flux Given

-using Juno, Requires, Reexport
+using MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward

-export Chain, Dense, RNN, LSTM, GRU, Conv,
+export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool,
   Dropout, LayerNorm, BatchNorm,
   params, mapleaves, cpu, gpu

 @reexport using NNlib
-using NNlib: @fix

 include("tracker/Tracker.jl")
 using .Tracker
@@ -37,6 +34,6 @@ include("layers/normalise.jl")
 include("data/Data.jl")

-@require CuArrays include("cuda/cuda.jl")
+@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" include("cuda/cuda.jl")

 end # module

src/cuda/cuda.jl
@@ -1,6 +1,6 @@
 module CUDA

-using CuArrays
+using ..CuArrays

 CuArrays.cudnn_available() && include("cudnn.jl")

src/cuda/cudnn.jl
@@ -1,24 +1,27 @@
-using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
+using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
   cudnnDataType, TensorDesc, FilterDesc

+using LinearAlgebra
+
 mutable struct DropoutDesc
-  ptr::Ptr{Void}
+  ptr::Ptr{Nothing}
   states::CuVector{UInt8}
 end

-Base.unsafe_convert(::Type{Ptr{Void}}, dd::DropoutDesc) = dd.ptr
+Base.unsafe_convert(::Type{Ptr{Nothing}}, dd::DropoutDesc) = dd.ptr

 function DropoutDesc(ρ::Real; seed::Integer=0)
   d = [C_NULL]
   s = Csize_t[0]
-  @check ccall((:cudnnCreateDropoutDescriptor,libcudnn), cudnnStatus_t, (Ptr{Ptr{Void}},), d)
-  @check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Void},Ptr{Csize_t}),libcudnn_handle[],s)
+  @check ccall((:cudnnCreateDropoutDescriptor,libcudnn), cudnnStatus_t, (Ptr{Ptr{Nothing}},), d)
+  @check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),libcudnn_handle[],s)
   states = CuArray{UInt8}(s[]) # TODO: can we drop this when ρ=0?
   desc = DropoutDesc(d[], states)
-  @check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},Ptr{Void},Cfloat,Ptr{Void},Csize_t,Culonglong),
+  @check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,Ptr{Nothing},Csize_t,Culonglong),
               desc,libcudnn_handle[],ρ,states,length(states),seed)
-  finalizer(desc, x ->
-    @check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
+  finalizer(desc) do x
+    @check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
+  end
   return desc
 end
@@ -43,10 +46,10 @@ const RNN_ALGO_PERSIST_DYNAMIC = 2
 # LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]

 function params(w::CuVector, input, hidden, n = 1)
-  slice(offset, shape) = reshape(w[offset+(1:prod(shape))], shape)
+  slice(offset, shape) = reshape(w[offset.+(1:prod(shape))], shape)
   wx = slice(0, (input, hidden*n))
   wh = slice(length(wx), (hidden, hidden*n))
-  bias = w[length(wx)+length(wh) + (1:hidden*n)]
+  bias = w[length(wx)+length(wh) .+ (1:hidden*n)]
   (wx, wh), bias
 end
@@ -57,14 +60,14 @@ mutable struct RNNDesc{T}
   params::CuVector{T}
   weights::NTuple{2,CuMatrix{T}}
   bias::CuVector{T}
-  ptr::Ptr{Void}
+  ptr::Ptr{Nothing}
 end

-Base.unsafe_convert(::Type{Ptr{Void}}, d::RNNDesc) = d.ptr
+Base.unsafe_convert(::Type{Ptr{Nothing}}, d::RNNDesc) = d.ptr

 function rnnParamSize(T, r, input)
   size = Csize_t[0]
-  @check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Ptr{Void},Ptr{Csize_t},Cint),
+  @check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Ptr{Nothing},Ptr{Csize_t},Cint),
     libcudnn_handle[], r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T))
   return Int(size[])÷sizeof(T)
 end
@@ -74,26 +77,27 @@ ngates(r::RNNDesc) = ngates(r.mode)
 function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
   d = [C_NULL]
-  @check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Void}},),d)
+  @check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Nothing}},),d)

   dropoutDesc = DropoutDesc(0)
   inputMode = LINEAR_INPUT
   direction = UNIDIRECTIONAL
   algo = RNN_ALGO_STANDARD
-  @check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Cint,Ptr{Void},Cint,Cint,Cint,Cint,Cint),
+  @check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Cint,Ptr{Nothing},Cint,Cint,Cint,Cint,Cint),
     libcudnn_handle[],d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))

   w = cuzeros(T, rnnParamSize(T, d[], input))
   # TODO: avoid reserve allocation here
   rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
-  finalizer(rd, x ->
-    @check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
+  finalizer(rd) do x
+    @check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
+  end
   return rd
 end

 function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc)
   size = Csize_t[0]
-  @check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Ptr{Ptr{Void}},Ptr{Csize_t}),
+  @check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Ptr{Ptr{Nothing}},Ptr{Csize_t}),
     libcudnn_handle[], r, seqlen, xdesc, size)
   return Int(size[])
 end
@@ -110,7 +114,7 @@ getworkspace(r::RNNDesc, seqlen, xdesc) =
 function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
   size = Csize_t[0]
-  @check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Void}, Ptr{Void}, Cint, Ptr{Ptr{Void}}, Ptr{Csize_t}),
+  @check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Nothing}, Ptr{Nothing}, Cint, Ptr{Ptr{Nothing}}, Ptr{Csize_t}),
     libcudnn_handle[], r, seqlen, xdesc, size)
   return Int(size[])
 end
@@ -119,19 +123,19 @@ function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
                          workspace, reserve=nothing) where T
   if reserve == nothing
     @check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
-                 (Ptr{Void}, Ptr{Void}, Cint,
-                  Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
-                  Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T},
-                  Ptr{Void}, Ptr{T},
-                  Ptr{Void}, Csize_t),
+                 (Ptr{Nothing}, Ptr{Nothing}, Cint,
+                  Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
+                  Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
+                  Ptr{Nothing}, Ptr{T},
+                  Ptr{Nothing}, Csize_t),
                  libcudnn_handle[], rnn, seqlen,
                  xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
                  workspace, length(workspace))
   else
     @check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t,
-                 (Ptr{Void}, Ptr{Void}, Cint,
-                  Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
-                  Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
-                  Ptr{Void}, Csize_t, Ptr{Void}, Csize_t),
+                 (Ptr{Nothing}, Ptr{Nothing}, Cint,
+                  Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
+                  Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
                  libcudnn_handle[], rnn, seqlen,
                  xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
                  workspace, length(workspace), reserve, length(reserve))
@@ -140,7 +144,7 @@ end
 xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]

-hDesc(h::Void) = C_NULL, C_NULL
+hDesc(h::Nothing) = C_NULL, C_NULL
 hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
 function hDesc(h::CuArray)
   TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
@@ -187,18 +191,18 @@ forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T
 function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
                               wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
   @check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
-               (Ptr{Void}, Ptr{Void}, Cint,
-                Ptr{Ptr{Void}}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T},
-                Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void},
-                Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
-                Ptr{Void}, Csize_t, Ptr{Void}, Csize_t),
+               (Ptr{Nothing}, Ptr{Nothing}, Cint,
+                Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
+                Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing},
+                Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
+                Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
               libcudnn_handle[], rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
               wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
 end

 function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
   # Same as above, any more efficient way?
-  dy = dy_ isa Integer ? zeros(y) : dy_
+  dy = dy_ isa Integer ? zero(y) : dy_
   yd = xDesc(y)
   dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
   dh = similar(h)
@@ -217,19 +221,19 @@ backwardData(rnn, y, dy, dho, hx, reserve) =
 function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw,
                                  workspace, reserve) where T
   @check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t,
-               (Ptr{Void}, Ptr{Void}, Cint,  # handle, rnnDesc, seqLength
-                Ptr{Ptr{Void}}, Ptr{T}, #x
-                Ptr{Void}, Ptr{T}, #hx
-                Ptr{Ptr{Void}}, Ptr{T}, #y
-                Ptr{Void}, Csize_t, #ws
-                Ptr{Void}, Ptr{T}, #dw
-                Ptr{Void}, Csize_t), #rs
+               (Ptr{Nothing}, Ptr{Nothing}, Cint,  # handle, rnnDesc, seqLength
+                Ptr{Ptr{Nothing}}, Ptr{T}, #x
+                Ptr{Nothing}, Ptr{T}, #hx
+                Ptr{Ptr{Nothing}}, Ptr{T}, #y
+                Ptr{Nothing}, Csize_t, #ws
+                Ptr{Nothing}, Ptr{T}, #dw
+                Ptr{Nothing}, Csize_t), #rs
               libcudnn_handle[], rnn, seqlen, xd, x, hd, h, yd, y,
               workspace, length(workspace), dwd, dw, reserve, length(reserve))
 end

 function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
-  dw = zeros(rnn.params)
+  dw = zero(rnn.params)
   cudnnRNNBackwardWeights(rnn, 1,
     xDesc(x), x, hDesc(h)..., xDesc(y), y,
     FilterDesc(T, (1, 1, length(dw))), dw,
@@ -241,17 +245,17 @@ end
 import ..Flux: Flux, relu
 import ..Tracker: TrackedArray

-using CUDAnative
-using CuArrays: @cuindex, cudims
+using .CuArrays.CUDAnative
+using .CuArrays: @cuindex, cudims

-function copy_transpose!(dst::CuArray, src::CuArray)
+function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
   function kernel(dst, src)
     I = @cuindex dst
     dst[I...] = src[reverse(I)...]
     return
   end
   blk, thr = cudims(dst)
-  @cuda (blk, thr) kernel(dst, src)
+  @cuda blocks=blk threads=thr kernel(dst, src)
   return dst
 end
@@ -324,7 +328,7 @@ end
     h_ = hBatch(x, data(h))
     dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
     (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
-    nobacksies(:RNN, (dx, unbroadcast(size(h), dh), dWi.', dWh.', db))
+    nobacksies(:RNN, (dx, unbroadcast(size(h), dh), transpose(dWi), transpose(dWh), db))
   end
 end
@@ -339,6 +343,6 @@ end
     (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
     nobacksies(:RNN,
       (dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc),
-       dWi.', dWh.', db))
+       transpose(dWi), transpose(dWh), db))
   end
 end
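The dotted `offset .+ (1:prod(shape))` in `params` above reflects a 0.7 language change: adding a scalar to a range now requires explicit broadcasting. A standalone illustration (variable names hypothetical):

```julia
offset = 4
offset .+ (1:3)  # => 5:7; on Julia 1.0, `offset + (1:3)` is a MethodError

w = collect(1:20)
reshape(w[offset .+ (1:6)], (2, 3))  # slice a flat buffer into a 2×3 block
```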

src/data/Data.jl
@@ -11,6 +11,8 @@ function __init__()
 end

 include("mnist.jl")
+export MNIST
+
 include("cmudict.jl")
 using .CMUDict

src/data/cmudict.jl
@@ -14,7 +14,7 @@ function load()
       return
     end
   end
-  info("Downloading CMUDict dataset")
+  @info "Downloading CMUDict dataset"
   mkpath(deps("cmudict"))
   for x in suffixes
     download("$cache_prefix/http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x",
@@ -24,25 +24,25 @@ end
 function phones()
   load()
-  Symbol.(first.(split.(split(readstring(deps("cmudict", "cmudict.phones")),
-                              "\n", keep = false), "\t")))
+  Symbol.(first.(split.(split(read(deps("cmudict", "cmudict.phones"),String),
+                              "\n", keepempty = false), "\t")))
 end

 function symbols()
   load()
-  Symbol.(split(readstring(deps("cmudict", "cmudict.symbols")),
-                "\n", keep = false))
+  Symbol.(split(read(deps("cmudict", "cmudict.symbols"),String),
+                "\n", keepempty = false))
 end

 function rawdict()
   load()
   Dict(String(xs[1]) => Symbol.(xs[2:end]) for xs in
-       filter(!isempty, split.(split(readstring(deps("cmudict", "cmudict")), "\n"))))
+       filter(!isempty, split.(split(read(deps("cmudict", "cmudict"),String), "\n"))))
 end

-validword(s) = ismatch(r"^[\w\-\.]+$", s)
+validword(s) = isascii(s) && occursin(r"^[\w\-\.]+$", s)

-cmudict() = filter((s, ps) -> validword(s), rawdict())
+cmudict() = filter(p -> validword(p.first), rawdict())

 alphabet() = ['A':'Z'..., '0':'9'..., '_', '-', '.']
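The renames in this file, shown in isolation (a sketch of the 0.7 string APIs used above; the sample string is made up):

```julia
s = "AA\tAA\n\n"
split(s, "\n", keepempty = false)      # was: split(s, "\n", keep = false)
occursin(r"^[\w\-\.]+$", "HELLO-1.2")  # was: ismatch(r"...", s); => true
# readstring(path) likewise becomes read(path, String)
```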

src/data/mnist.jl
@@ -1,11 +1,17 @@
 module MNIST

-using GZip, Colors
+using CodecZlib, Colors

 const Gray = Colors.Gray{Colors.N0f8}

 const dir = joinpath(@__DIR__, "../../deps/mnist")

+function gzopen(f, file)
+  open(file) do io
+    f(GzipDecompressorStream(io))
+  end
+end
+
 function load()
   mkpath(dir)
   cd(dir) do
@@ -14,10 +20,10 @@ function load()
          "t10k-images-idx3-ubyte",
          "t10k-labels-idx1-ubyte"]
       isfile(file) && continue
-      info("Downloading MNIST dataset")
+      @info "Downloading MNIST dataset"
       download("https://cache.julialang.org/http://yann.lecun.com/exdb/mnist/$file.gz", "$file.gz")
       open(file, "w") do io
-        write(io, GZip.open(read, "$file.gz"))
+        write(io, gzopen(read, "$file.gz"))
       end
     end
   end
@@ -49,7 +55,7 @@ function labelheader(io::IO)
 end

 function rawimage(io::IO)
-  img = Array{Gray}(NCOLS, NROWS)
+  img = Array{Gray}(undef, NCOLS, NROWS)
   for i in 1:NCOLS, j in 1:NROWS
     img[i, j] = reinterpret(Colors.N0f8, read(io, UInt8))
   end
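The `gzopen` helper added above replaces `GZip.open` with a CodecZlib stream. Roughly how it is used (the file name is hypothetical):

```julia
using CodecZlib

bytes = open("train-images-idx3-ubyte.gz") do io
  read(GzipDecompressorStream(io))  # stream-decompress, then read everything
end
```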

src/data/sentiment.jl
@@ -4,8 +4,8 @@ using ZipFile
 using ..Data: deps

 function load()
-  isfile(deps("sentiment.zip")) || return
-  info("Downloading sentiment treebank dataset")
+  isfile(deps("sentiment.zip")) && return
+  @info "Downloading sentiment treebank dataset"
   download("https://cache.julialang.org/https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip",
            deps("sentiment.zip"))
 end
@@ -14,7 +14,7 @@ getfile(r, name) = r.files[findfirst(x -> x.name == name, r.files)]
 function getfile(name)
   r = ZipFile.Reader(deps("sentiment.zip"))
-  text = readstring(getfile(r, "trees/$name"))
+  text = read(getfile(r, "trees/$name"), String)
   close(r)
   return text
 end
@@ -26,15 +26,16 @@ totree_(n, a, b) = Tree{Any}((parse(Int, n), nothing), totree(a), totree(b))
 totree(t::Expr) = totree_(t.args...)

 function parsetree(s)
-  s = replace(s, r"\$", s -> "\\\$")
-  s = replace(s, r"[^\s\(\)]+", s -> "\"$s\"")
-  s = replace(s, " ", ", ")
-  return totree(parse(s))
+  s = replace(s, "\\" => "")
+  s = replace(s, "\$" => "\\\$")
+  s = replace(s, r"[^ \n\(\)]+" => s -> "\"$s\"")
+  s = replace(s, " " => ", ")
+  return totree(Meta.parse(s))
 end

 function gettrees(name)
   load()
-  ss = split(getfile("$name.txt"), '\n', keep = false)
+  ss = split(getfile("$name.txt"), '\n', keepempty = false)
   return parsetree.(ss)
 end

src/layers/basic.jl
@@ -16,19 +16,19 @@ m(x) == m[2](m[1](x))
 `Chain` also supports indexing and slicing, e.g. `m[2]` or `m[1:end-1]`.
 `m[1:3](x)` will calculate the output of the first three layers.
 """
-type Chain
+struct Chain
   layers::Vector{Any}
   Chain(xs...) = new([xs...])
 end

-@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
-@forward Chain.layers Base.start, Base.next, Base.done
+@forward Chain.layers Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!
+@forward Chain.layers Base.iterate

 children(c::Chain) = c.layers
 mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
 adapt(T, c::Chain) = Chain(map(x -> adapt(T, x), c.layers)...)

-(c::Chain)(x) = foldl((x, m) -> m(x), x, c.layers)
+(c::Chain)(x) = foldl((x, m) -> m(x), c.layers; init = x)

 Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
@@ -38,10 +38,7 @@ function Base.show(io::IO, c::Chain)
   print(io, ")")
 end

-# Seem to need this for `accumulate`; try removing on 0.7
-Base.rcum_promote_type(op, ::Type, ::Type{Any}) = Any
-
-activations(c::Chain, x) = accumulate((x, m) -> m(x), x, c.layers)
+activations(c::Chain, x) = accumulate((x, m) -> m(x), c.layers, init = x)

 """
     Dense(in::Integer, out::Integer, σ = identity)
@@ -76,11 +73,11 @@ function Dense(in::Integer, out::Integer, σ = identity;
   return Dense(param(initW(out, in)), param(initb(out)), σ)
 end

-treelike(Dense)
+@treelike Dense

 function (a::Dense)(x)
   W, b, σ = a.W, a.b, a.σ
-  @fix σ.(W*x .+ b)
+  σ.(W*x .+ b)
 end

 function Base.show(io::IO, l::Dense)
@@ -107,7 +104,7 @@ end
 Diagonal(in::Integer; initα = ones, initβ = zeros) =
   Diagonal(param(initα(in)), param(initβ(in)))

-treelike(Diagonal)
+@treelike Diagonal

 function (a::Diagonal)(x)
   α, β = a.α, a.β
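The new `(c::Chain)(x) = foldl((x, m) -> m(x), c.layers; init = x)` definition above threads the input through each layer in turn. The same pattern outside of `Chain`, with plain functions standing in for layers:

```julia
layers = [x -> 2 .* x, x -> x .+ 1]
apply(x) = foldl((x, f) -> f(x), layers; init = x)
apply([1.0, 2.0])  # => [3.0, 5.0]
```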

src/layers/conv.jl
@@ -1,6 +1,6 @@
 using NNlib: conv

-@generated sub2(::Type{Val{N}}) where N = :(Val{$(N-2)})
+@generated sub2(::Val{N}) where N = :(Val($(N-2)))

 expand(N, i::Tuple) = i
 expand(N, i::Integer) = ntuple(_ -> i, N)
@@ -28,14 +28,14 @@ end
 Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
      stride = 1, pad = 0, dilation = 1) where {T,N} =
-  Conv(σ, w, b, expand.(sub2(Val{N}), (stride, pad, dilation))...)
+  Conv(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)

 Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
      stride = 1, pad = 0, dilation = 1) where N =
   Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
        stride = stride, pad = pad, dilation = dilation)

-Flux.treelike(Conv)
+@treelike Conv

 function (c::Conv)(x)
   # TODO: breaks gpu broadcast :(
@@ -50,3 +50,48 @@ function Base.show(io::IO, l::Conv)
   l.σ == identity || print(io, ", ", l.σ)
   print(io, ")")
 end
+
+"""
+    MaxPool(k)
+
+Max pooling layer. `k` stands for the size of the window for each dimension of the input.
+
+Takes the keyword arguments `pad` and `stride`.
+"""
+struct MaxPool{N}
+  k::NTuple{N,Int}
+  pad::NTuple{N,Int}
+  stride::NTuple{N,Int}
+end
+
+MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
+  MaxPool(k, expand(Val(N), pad), expand(Val(N), stride))
+
+(m::MaxPool)(x) = maxpool(x, m.k; pad = m.pad, stride = m.stride)
+
+function Base.show(io::IO, m::MaxPool)
+  print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
+end
+
+"""
+    MeanPool(k)
+
+Mean pooling layer. `k` stands for the size of the window for each dimension of the input.
+
+Takes the keyword arguments `pad` and `stride`.
+"""
+struct MeanPool{N}
+  k::NTuple{N,Int}
+  pad::NTuple{N,Int}
+  stride::NTuple{N,Int}
+end
+
+MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
+  MeanPool(k, expand(Val(N), pad), expand(Val(N), stride))
+
+(m::MeanPool)(x) = meanpool(x, m.k; pad = m.pad, stride = m.stride)
+
+function Base.show(io::IO, m::MeanPool)
+  print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
+end

src/layers/normalise.jl
@@ -58,7 +58,7 @@ end
 LayerNorm(h::Integer) =
   LayerNorm(Diagonal(h))

-treelike(LayerNorm)
+@treelike LayerNorm

 (a::LayerNorm)(x) = a.diag(normalise(x))
@@ -108,7 +108,7 @@ mutable struct BatchNorm{F,V,W,N}
 end

 BatchNorm(chs::Integer, λ = identity;
-          initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1) =
+          initβ = (i) -> zeros(i), initγ = (i) -> ones(i), ϵ = 1e-8, momentum = .1) =
   BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
             zeros(chs), ones(chs), ϵ, momentum, true)
@@ -130,13 +130,13 @@ function (BN::BatchNorm)(x)
     ϵ = data(convert(T, BN.ϵ))
     axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
-    μ = mean(x, axes)
-    σ = sqrt.(mean((x .- μ).^2, axes) .+ ϵ)
+    μ = mean(x, dims = axes)
+    σ = sqrt.(mean((x .- μ).^2, dims = axes) .+ ϵ)

     # update moving mean/std
     mtm = data(convert(T, BN.momentum))
-    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* squeeze(data(μ), (axes...))
-    BN.σ = (1 - mtm) .* BN.σ .+ mtm .* squeeze(data(σ), (axes...)) .* m ./ (m - 1)
+    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* dropdims(data(μ), dims = (axes...,))
+    BN.σ = (1 - mtm) .* BN.σ .+ mtm .* dropdims(data(σ), dims = (axes...,)) .* m ./ (m - 1)
   end

   let λ = BN.λ
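The reduction calls above use the 0.7 API: `mean` takes `dims` as a keyword and `dropdims` replaces `squeeze`. In isolation:

```julia
using Statistics

x = reshape(1.0:24.0, 2, 3, 4)
μ = mean(x, dims = [1, 3])  # a 1×3×1 array
dropdims(μ, dims = (1, 3))  # => 3-element Array{Float64,1}
```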

src/layers/recurrent.jl
@@ -1,4 +1,4 @@
-gate(h, n) = (1:h) + h*(n-1)
+gate(h, n) = (1:h) .+ h*(n-1)
 gate(x::AbstractVector, h, n) = x[gate(h,n)]
 gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]
@@ -38,7 +38,7 @@ function (m::Recur)(xs...)
   return y
 end

-treelike(Recur, (:cell, :init))
+@treelike Recur cell, init

 Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
@@ -84,7 +84,7 @@ end
 RNNCell(in::Integer, out::Integer, σ = tanh;
         init = glorot_uniform) =
   RNNCell(σ, param(init(out, in)), param(init(out, out)),
-          param(zeros(out)), param(initn(out)))
+          param(zeros(out)), param(init(out)))

 function (m::RNNCell)(h, x)
   σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
@@ -94,7 +94,7 @@ end
 hidden(m::RNNCell) = m.h

-treelike(RNNCell)
+@treelike RNNCell

 function Base.show(io::IO, l::RNNCell)
   print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1))
@@ -123,13 +123,12 @@ end
 function LSTMCell(in::Integer, out::Integer;
                   init = glorot_uniform)
   cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)),
-                  param(initn(out)), param(initn(out)))
-  cell.b.data[gate(out, 2)] = 1
+                  param(init(out)), param(init(out)))
+  cell.b.data[gate(out, 2)] .= 1
   return cell
 end

-function (m::LSTMCell)(h_, x)
-  h, c = h_ # TODO: nicer syntax on 0.7
+function (m::LSTMCell)((h, c), x)
   b, o = m.b, size(h, 1)
   g = m.Wi*x .+ m.Wh*h .+ b
   input = σ.(gate(g, o, 1))
@@ -143,7 +142,7 @@ end
 hidden(m::LSTMCell) = (m.h, m.c)

-treelike(LSTMCell)
+@treelike LSTMCell

 Base.show(io::IO, l::LSTMCell) =
   print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷4, ")")
@@ -170,7 +169,7 @@ end
 GRUCell(in, out; init = glorot_uniform) =
   GRUCell(param(init(out*3, in)), param(init(out*3, out)),
-          param(zeros(out*3)), param(initn(out)))
+          param(zeros(out*3)), param(init(out)))

 function (m::GRUCell)(h, x)
   b, o = m.b, size(h, 1)
@@ -178,13 +177,13 @@ function (m::GRUCell)(h, x)
   r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
   z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
   h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))
-  h′ = (1.-z).*h̃ .+ z.*h
+  h′ = (1 .- z).*h̃ .+ z.*h
   return h′, h′
 end

 hidden(m::GRUCell) = m.h

-treelike(GRUCell)
+@treelike GRUCell

 Base.show(io::IO, l::GRUCell) =
   print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")

src/layers/stateless.jl
@@ -5,7 +5,7 @@ using NNlib: logsoftmax, logσ
 mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)

 function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
-  @fix -sum(y .* log.(ŷ) .* weight) / size(y, 2)
+  -sum(y .* log.(ŷ) .* weight) / size(y, 2)
 end

 @deprecate logloss(x, y) crossentropy(x, y)
@@ -47,7 +47,7 @@ logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
 Normalise each column of `x` to mean 0 and standard deviation 1.
 """
 function normalise(x::AbstractVecOrMat)
-  μ′ = mean(x, 1)
-  σ′ = std(x, 1, mean = μ′)
+  μ′ = mean(x, dims = 1)
+  σ′ = std(x, dims = 1, mean = μ′)
   return (x .- μ′) ./ σ′
 end

src/onehot.jl
@@ -32,20 +32,21 @@ import Adapt.adapt
 adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))

-@require CuArrays begin
-  import CuArrays: CuArray, cudaconvert
-  Base.Broadcast._containertype(::Type{<:OneHotMatrix{<:CuArray}}) = CuArray
+@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
+  import .CuArrays: CuArray, cudaconvert
+  import Base.Broadcast: BroadcastStyle, ArrayStyle
+  BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
   cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
 end

 function onehot(l, labels)
-  i = findfirst(labels, l)
+  i = something(findfirst(isequal(l), labels), 0)
   i > 0 || error("Value $l is not in labels")
   OneHotVector(i, length(labels))
 end

 function onehot(l, labels, unk)
-  i = findfirst(labels, l)
+  i = something(findfirst(isequal(l), labels), 0)
   i > 0 || return onehot(unk, labels)
   OneHotVector(i, length(labels))
 end
@@ -53,11 +54,15 @@ end
 onehotbatch(ls, labels, unk...) =
   OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])

-argmax(y::AbstractVector, labels = 1:length(y)) =
-  labels[findfirst(y, maximum(y))]
+onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]

-argmax(y::AbstractMatrix, l...) =
-  squeeze(mapslices(y -> argmax(y, l...), y, 1), 1)
+onecold(y::AbstractMatrix, labels...) =
+  dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1)
+
+function argmax(xs...)
+  Base.depwarn("`argmax(...)` is deprecated, use `onecold(...)` instead.", :argmax)
+  return onecold(xs...)
+end

 # Ambiguity hack
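The `something(findfirst(isequal(l), labels), 0)` idiom above exists because 0.7's `findfirst` takes a predicate and returns `nothing` rather than 0 when there is no match:

```julia
labels = [:a, :b, :c]
findfirst(isequal(:b), labels)               # => 2
findfirst(isequal(:z), labels)               # => nothing
something(findfirst(isequal(:z), labels), 0) # => 0, restoring the old sentinel
```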

src/optimise/Optimise.jl
@@ -2,14 +2,14 @@ module Optimise
 export train!,
   SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
-  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
+  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException

 struct Param{T}
   x::T
   Δ::T
 end

-Base.convert(::Type{Param}, x::AbstractArray) = Param(x, zeros(x))
+Param(x::AbstractArray) = Param(x, zero(x))

 include("optimisers.jl")
 include("interface.jl")
@@ -17,6 +17,7 @@ include("train.jl")
 using Flux.Tracker: TrackedArray

-Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)
+Param(x::TrackedArray) = Param(x.data, x.grad)
+# Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)

 end

src/optimise/optimisers.jl
@@ -14,7 +14,7 @@ function descentweightdecay(p::Param, η::Real, γ::Real)
 end

 function momentum(p::Param, ρ, η)
-  v = zeros(p.x)
+  v = zero(p.x)
   function ()
     @. v = ρ * v - η * p.Δ
     @. p.Δ = -v
@@ -23,7 +23,7 @@ end
 # Ref. https://arxiv.org/pdf/1212.0901.pdf
 function nesterov(p::Param, ρ, η)
-  v = zeros(p.x)
+  v = zero(p.x)
   function ()
     d = @. ρ^2 * v - (1+ρ) * η * p.Δ
     @. v = ρ*v - η*p.Δ
@@ -32,7 +32,7 @@ function nesterov(p::Param, ρ, η)
 end

 function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
-  acc = zeros(p.x)
+  acc = zero(p.x)
   function ()
     @. acc = ρ * acc + (1 - ρ) * p.Δ^2
     @. p.Δ *= η / √(acc + ϵ)
@@ -40,7 +40,7 @@ function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
 end

 function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
-  acc = zeros(p.x) .+ ϵ
+  acc = zero(p.x) .+ ϵ
   function ()
     @. acc += p.Δ^2
     @. p.Δ *= η / √(acc + ϵ)
@@ -48,8 +48,8 @@ function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
 end

 function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
-  acc = zeros(p.x)
-  Δacc = zeros(p.x)
+  acc = zero(p.x)
+  Δacc = zero(p.x)
   function ()
     @. acc = ρ * acc + (1 - ρ) * p.Δ^2
     @. p.Δ *= √(Δacc + ϵ) / √(acc + ϵ)
@@ -58,8 +58,8 @@ function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
 end

 function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zeros(p.x)
-  vt = zeros(p.x)
+  mt = zero(p.x)
+  vt = zero(p.x)
   β1p, β2p = β1, β2
   function ()
     @. mt = β1 * mt + (1 - β1) * p.Δ
@@ -71,8 +71,8 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ
 end

 function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zeros(p.x)
-  ut = zeros(p.x)
+  mt = zero(p.x)
+  ut = zero(p.x)
   β1p = β1
   function ()
     @. mt = β1 * mt + (1 - β1) * p.Δ
@@ -83,9 +83,9 @@ function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999,
 end

 function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zeros(p.x)
-  vt = zeros(p.x) .+ ϵ
-  v̂t = zeros(p.x) .+ ϵ
+  mt = zero(p.x)
+  vt = zero(p.x) .+ ϵ
+  v̂t = zero(p.x) .+ ϵ
   function ()
     @. mt = β1 * mt + (1 - β1) * p.Δ
     @. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
@@ -95,8 +95,8 @@ function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999,
 end

 function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zeros(p.x)
-  vt = zeros(p.x)
+  mt = zero(p.x)
+  vt = zero(p.x)
   β1p, β2p = β1, β2
   function ()
     @. mt = β1 * mt + (1 - β1) * p.Δ
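The pervasive `zeros(p.x)` → `zero(p.x)` change is another 0.7 rename: `zeros` no longer accepts an array to take shape and element type from, while `zero` does:

```julia
p = rand(Float32, 3, 2)
v = zero(p)         # 3×2 Array{Float32,2} of zeros, same shape/eltype as p
size(v) == size(p)  # => true
```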

src/optimise/train.jl
@@ -1,5 +1,6 @@
 using Juno
 using Flux.Tracker: back!
+import Base.depwarn

 runall(f) = f
 runall(fs::AbstractVector) = () -> foreach(call, fs)
@@ -14,6 +15,25 @@ macro interrupts(ex)
     end)
 end

+struct StopException <: Exception end
+
+"""
+    stop()
+
+Call `Flux.stop()` in a callback to indicate when a callback condition is met.
+This will cause the training loop to stop and exit.
+
+```julia
+# Example callback:
+
+cb = function ()
+  accuracy() > 0.9 && Flux.stop()
+end
+```
+"""
+function stop()
+  throw(StopException())
+end
+
 """
     train!(loss, data, opt)
@@ -36,10 +56,21 @@ function train!(loss, data, opt; cb = () -> ())
   cb = runall(cb)
   opt = runall(opt)
   @progress for d in data
-    l = loss(d...)
-    @interrupts back!(l)
-    opt()
-    cb() == :stop && break
+    try
+      l = loss(d...)
+      @interrupts back!(l)
+      opt()
+      if cb() == :stop
+        depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
+        break
+      end
+    catch ex
+      if ex isa StopException
+        break
+      else
+        rethrow(ex)
+      end
+    end
   end
 end
@@ -59,7 +90,7 @@ hello
 """
 macro epochs(n, ex)
   :(@progress for i = 1:$(esc(n))
-      info("Epoch $i")
+      @info "Epoch $i"
       $(esc(ex))
     end)
 end

src/tracker/Tracker.jl
@@ -12,7 +12,7 @@ tracker(x) = nothing
 istracked(x) = tracker(x) ≠ nothing
 isleaf(x) = !istracked(x) || isleaf(tracker(x))
 grad(x) = grad(tracker(x))
-grad(::Void) = nothing
+grad(::Nothing) = nothing
 data(x) = x

 struct Call{F,As<:Tuple}
@@ -20,7 +20,7 @@ struct Call{F,As<:Tuple}
   args::As
 end

-Call(f, args) = Call{typeof(f),typeof(args)}(f, args)
+Call(f::F, args::T) where {F,T} = Call{F,T}(f, args)
 Call() = Call(nothing, ())

 # When deserialising, the object_id changes
@@ -35,7 +35,7 @@ mutable struct Tracked{T}
   grad::T
   Tracked{T}(f::Call) where T = new(0, f, false)
   Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad)
-  Tracked{T}(f::Call{Void}, grad::T) where T = new(0, f, true, grad)
+  Tracked{T}(f::Call{Nothing}, grad::T) where T = new(0, f, true, grad)
 end

 istracked(x::Tracked) = true
@@ -46,7 +46,7 @@ track(f::Call, x) = Tracked{typeof(x)}(f)
 function _forward end

-function track(f, xs...; kw...)
+function track(f::F, xs...; kw...) where F
   y, back = _forward(f, xs...; kw...)
   track(Call(back, tracker.(xs)), y)
 end
@@ -77,10 +77,9 @@ include("numeric.jl")
 Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
 `f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
-the sign of the gradient applied to `x`.
-"""
+the sign of the gradient applied to `x`."""
 hook(f, x) = istracked(x) ? track(hook, f, x) : x
-@grad hook(f, x) = x, Δ -> (nothing, f(Δ))
+@grad hook(f, x) = data(x), Δ -> (nothing, f(Δ))

 """
     checkpoint(f, args...)
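A sketch of `hook` as documented above: the forward value is unchanged, but the gradient flowing into `x` is transformed, here by flipping its sign:

```julia
using Flux
using Flux.Tracker: hook, back!, grad

x = param([1.0, 2.0])
y = sum(hook(-, x))  # forward value: 3.0, as without the hook
back!(y)
grad(x)              # => [-1.0, -1.0] instead of [1.0, 1.0]
```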

View File

@ -1,3 +1,11 @@
import Base: *
import LinearAlgebra
import LinearAlgebra: inv, \, /
using Statistics
using LinearAlgebra: Transpose, Adjoint, diagm, diag
struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N} struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
tracker::Tracked{A} tracker::Tracked{A}
data::A data::A
@ -21,24 +29,20 @@ TrackedArray(c::Call, x::A) where A <: AbstractArray =
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray = TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, Δ), x, Δ) TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, Δ), x, Δ)
TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zeros(x)) TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zero(x))
Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T} Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} = Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
print(io, "TrackedArray{…,$A}") print(io, "TrackedArray{…,$A}")
function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true) function Base.summary(io::IO, x::TrackedArray)
if repr print(io, "Tracked ")
print(io, "param(") summary(io, data(x))
Base.showarray(io, data(X), true)
print(io, ")")
else
header && print(io, "Tracked ")
Base.showarray(io, data(X), false, header = header)
end
end end
Base.print_array(io::IO, x::TrackedArray) = Base.print_array(io, data(x))
Base.setindex!(xs::TrackedArray, v, i...) = Base.setindex!(xs::TrackedArray, v, i...) =
error("Can't differentiate `setindex!`") error("Can't differentiate `setindex!`")
@ -46,7 +50,7 @@ back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back
# Fallthrough methods # Fallthrough methods
for f in :[Base.size, Base.ndims].args for f in :[Base.size, Base.ndims, Base.collect].args
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...) @eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
end end
@ -58,9 +62,11 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T) Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
Base.:(==)(x::TrackedArray, y) = data(x) == y for op in [:(==), :≈]
Base.:(==)(y, x::TrackedArray) = y == data(x) @eval Base.$op(x::TrackedArray, y::AbstractArray) = Base.$op(data(x), y)
Base.:(==)(x::TrackedArray, y::TrackedArray) = data(x) == data(y) @eval Base.$op(x::AbstractArray, y::TrackedArray) = Base.$op(x, data(y))
@eval Base.$op(x::TrackedArray, y::TrackedArray) = Base.$op(data(x), data(y))
end
# Array Stdlib # Array Stdlib
@ -79,37 +85,20 @@ Base.:-(xs::TrackedArray) = track(-, xs)
@grad -(xs) = -data(xs), Δ -> (-Δ,) @grad -(xs) = -data(xs), Δ -> (-Δ,)
Base.transpose(xs::TrackedArray) = track(transpose, xs) Base.transpose(xs::TrackedArray) = track(transpose, xs)
Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs) Base.adjoint(xs::TrackedArray) = track(adjoint, xs)
@grad transpose(xs) = data(xs).', Δ -> (reshape(Δ.', size(xs)),) @grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),)
@grad ctranspose(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),) @grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...) Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
@grad function repmat(xs, m, n = 1) @grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs)))
repmat(data(xs), m, n), function (Δ)
Δ′ = similar(xs)
S = size(xs)
for (i,v) in enumerate(data(Δ))
d1 = divrem(i-1, S[1]*m)
x = d1[2] % S[1]+1
y = d1[1] % S[2]+1
Δ′[x, y] += v
end
return (nobacksies(:repmat, Δ′), nothing, nothing)
end
end
Base.repeat(A::TrackedArray; kw...) = track(repeat, A; kw...)
@grad function repeat(xs; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A)))
repeat(data(xs), inner = inner, outer = outer), function (Δ) repeat(data(xs), inner = inner, outer = outer), function (Δ)
Δ′ = zero(xs) Δ′ = zero(xs)
S = size(xs) S = size(xs)
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′ # Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
for (dest_idx, val) in enumerate(IndexCartesian(), data(Δ)) for (dest_idx, val) in pairs(IndexCartesian(), data(Δ))
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then # First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
# wrap around based on original size S. # wrap around based on original size S.
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)] src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
@ -119,8 +108,8 @@ Base.repeat(A::TrackedArray; kw...) = track(repeat, A; kw...)
end end
end end
for f in [:vcat, :hcat] for f in [:vcat, :hcat]
UArray = :(Union{TrackedArray,Vector,Matrix,Adjoint,Transpose})
@eval begin @eval begin
# This section is a bit of a hack since julia doesn't have a standardised # This section is a bit of a hack since julia doesn't have a standardised
# promotion mechanism for concatenation yet # promotion mechanism for concatenation yet
@ -129,18 +118,18 @@ for f in [:vcat, :hcat]
# It should support tracked concatenation with rank ∈ (1,2) with a # It should support tracked concatenation with rank ∈ (1,2) with a
# TrackedArray anywhere among the arguments. This works as long as base has # TrackedArray anywhere among the arguments. This works as long as base has
# other functions that capture `(::Union{Vector,RowVector,Matrix}...)`. # other functions that capture `(::Union{Vector,RowVector,Matrix}...)`.
Base.$f(a::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a...) Base.$f(a::$UArray...) = track($f, a...)
# It should support tracked concatenation with rank>2 if the TrackedArray is # It should support tracked concatenation with rank>2 if the TrackedArray is
# first # first
Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...) Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...)
Base.$f(a::TrackedArray, b::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a, b...) # resolves ambiguity introduced by previous row Base.$f(a::TrackedArray, b::$UArray...) = track($f, a, b...) # resolves ambiguity introduced by previous row
# It should support tracked concatenation with rank>2 if the TrackedArray is # It should support tracked concatenation with rank>2 if the TrackedArray is
# second # second
Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...) Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...)
Base.$f(a::Union{Vector,RowVector,Matrix}, b::TrackedArray, Base.$f(a::Union{Vector,Matrix,Adjoint,Transpose}, b::TrackedArray,
c::Union{TrackedArray,Vector,RowVector,Matrix}...) = c::$UArray...) =
track($f, a, b, c...) # resolves ambiguity introduced by previous row track($f, a, b, c...) # resolves ambiguity introduced by previous row
end end
end end
@ -175,21 +164,23 @@ end
end end
end end
Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...) Base.cat(a::TrackedArray; dims) = track(cat, a, dims = dims)
Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...) Base.cat(a::TrackedArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
Base.cat(a::TrackedArray, b::AbstractArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
@grad function cat(dims, Xs...) @grad function cat(Xs...; dims)
cat(dims, data.(Xs)...), function (Δ) cat(data.(Xs)..., dims = dims), function (Δ)
start = ntuple(i -> 0, Val{ndims(Δ)}) start = ntuple(i -> 0, Val(ndims(Δ)))
Δs = [begin Δs = [begin
dim_xs = 1:ndims(xs) dim_xs = 1:ndims(xs)
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)}) till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ)))
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)}) xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ)))
d = reshape(Δ[xs_in_Δ...],size(xs)) d = reshape(Δ[xs_in_Δ...],size(xs))
start = start .+ till_xs start = start .+ till_xs
d d
end for xs in Xs] end for xs in Xs]
return (nothing, Δs...,) return (Δs...,)
end end
end end
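For illustration, the backward pass above slices Δ back into one block per input; a hedged sketch:

```julia
using Flux.Tracker: gradient

gx, gy = gradient((x, y) -> sum(cat(x, y, dims = 1)), ones(2), ones(3))
# gx == [1.0, 1.0] and gy == [1.0, 1.0, 1.0]:
# each argument recovers exactly its own slice of the cotangent.
```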
@ -216,100 +207,123 @@ Base.kron(a::TrackedMatrix, b::TrackedMatrix) = _kron(a, b)
Base.kron(a::TrackedMatrix, b::AbstractMatrix) = _kron(a, b) Base.kron(a::TrackedMatrix, b::AbstractMatrix) = _kron(a, b)
Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b) Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b)
inv(A::TrackedArray) = Tracker.track(inv, A)
@grad function inv(A)
return inv(Tracker.data(A)), function (Δ)
Ainv = inv(A)
∇A = - Ainv' * Δ * Ainv'
return (∇A, )
end
end
# (/) rdivide
A::TrackedArray / B::TrackedArray = Tracker.track(/, A, B)
A::AbstractVecOrMat / B::TrackedArray = Tracker.track(/, A, B)
A::TrackedArray / B::AbstractVecOrMat = Tracker.track(/, A, B)
@grad function (A / B)
return Tracker.data(A) / Tracker.data(B), function (Δ)
Binv = inv(B)
∇B = - Binv' * A' * Δ * Binv'
return (Δ * Binv', ∇B)
end
end
# (\) ldivide (left vec divide needs more work to resolve dispatch ambiguity)
A::TrackedArray \ B::TrackedArray = Tracker.track(\, A, B)
A::AbstractArray \ B::TrackedArray = Tracker.track(\, A, B)
A::TrackedArray \ B::AbstractVecOrMat = Tracker.track(\, A, B)
@grad function (A \ B)
return Tracker.data(A) \ Tracker.data(B), function (Δ)
Ainv = inv(A)
∇A = - Ainv' * Δ * B' * Ainv'
return (∇A, Ainv' * Δ)
end
end
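For reference, these adjoints implement the standard matrix-calculus identities (as in e.g. the Matrix Cookbook), with Δ the cotangent of the output:

```
inv:  C = A⁻¹    ⇒  ∇A = -A⁻ᵀ Δ A⁻ᵀ
(/):  C = A B⁻¹  ⇒  ∇A = Δ B⁻ᵀ              ∇B = -B⁻ᵀ Aᵀ Δ B⁻ᵀ
(\):  C = A⁻¹ B  ⇒  ∇A = -A⁻ᵀ Δ Bᵀ A⁻ᵀ      ∇B = A⁻ᵀ Δ
```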
# Reductions # Reductions
Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim) Base.sum(xs::TrackedArray; dims = :) = track(sum, xs, dims = dims)
Base.sum(xs::TrackedArray) = track(sum, xs)
Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs)) Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))
@grad sum(xs, dim...) = sum(data(xs), dim...), @grad sum(xs; dims = :) = sum(data(xs), dims = dims),
Δ -> (zero(xs) .+ Δ, map(_->nothing,dim)...) Δ -> (zero(xs) .+ Δ, )
Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim) Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
Base.prod(xs::TrackedArray) = track(prod, xs) Base.prod(xs::TrackedArray) = track(prod, xs)
Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs)) Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
@grad prod(xs) = prod(data(xs)), Δ -> (prod(xs) ./ xs .* Δ,) @grad prod(xs) = prod(data(xs)), Δ -> (prod(xs) ./ xs .* Δ,)
@grad prod(xs, dim) = prod(data(xs), dim), @grad prod(xs, dim) = prod(data(xs), dims = dim),
Δ -> (nobacksies(:sum, Δ -> (nobacksies(:sum,
reshape(.*(circshift.([reshape(data(xs), length(xs))], 1:length(xs)-1)...), size(xs)) .* Δ), reshape(.*(circshift.([reshape(data(xs), length(xs))], 1:length(xs)-1)...), size(xs)) .* Δ),
nothing) nothing)
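A small worked example of the `prod` adjoint (hedged): ∂prod(x)/∂xᵢ = ∏_{j≠i} xⱼ, which equals prod(x)/xᵢ when no entry is zero — the circshift construction computes the same products without dividing.

```julia
using Flux.Tracker: gradient

gradient(x -> prod(x), [2.0, 3.0, 4.0])[1]
# == [12.0, 8.0, 6.0], i.e. prod(x) ./ x elementwise
```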
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...) Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
Base.mean(xs::TrackedArray) = track(mean, xs) Statistics.mean(xs::TrackedArray; dims = :) = track(mean, xs, dims = dims)
Base.mean(xs::TrackedArray, region) = track(mean, xs, region)
Base.maximum(xs::TrackedArray) = track(maximum, xs) Base.maximum(xs::TrackedArray; dims = :) = track(maximum, xs, dims = dims)
Base.maximum(xs::TrackedArray, region) = track(maximum, xs, region) Base.minimum(xs::TrackedArray; dims = :) = track(minimum, xs, dims = dims)
Base.minimum(xs::TrackedArray) = track(minimum, xs)
Base.minimum(xs::TrackedArray, region) = track(minimum, xs, region)
LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys) import LinearAlgebra: dot
LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys) dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
@grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs) @grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs)
# Hacks to get std working # Hacks to get std working
Base.std(x::TrackedArray; mean = Base.mean(x)) = Statistics.std(x::TrackedArray; dims = :, mean = Statistics.mean(x, dims = dims)) = _std(x,mean,dims)
sqrt.(sum((x .- mean).^2) ./ (length(x)-1)) _std(x::TrackedArray, mean, dims) = sqrt.(sum((x .- mean).^2, dims = dims) ./ (mapreduce(i -> size(x,i),*, dims) - 1))
Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) = _std(x::TrackedArray, mean, ::Colon) = sqrt.(sum((x .- mean).^2) ./ (length(x) - 1))
sqrt.(sum((x .- mean).^2, dim) ./ (size(x, dim)-1))
Base.vecnorm(x::TrackedArray, p::Real = 2) = LinearAlgebra.norm(x::TrackedArray, p::Real = 2) =
sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0 sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0
@grad mean(xs) = mean(data(xs)), Δ -> (Δ / length(xs),) @grad mean(xs; dims = :) = mean(data(xs), dims=dims), Δ -> (_backmean(xs,Δ,dims),)
@grad mean(xs, region) = mean(data(xs), region), Δ -> (zero(xs) .+ Δ ./ prod(size(xs, region...)),nothing) _backmean(xs, Δ, ::Colon) = zero(xs) .+ Δ ./ length(xs)
_backmean(xs, Δ, dims) = zero(xs) .+ Δ ./ mapreduce(i -> size(data(xs),i),*,dims)
@grad function maximum(xs, r...) @grad function maximum(xs; dims = :)
maximum(data(xs), r...), function (Δ) maximum(data(xs), dims = dims), function (Δ)
Δ′ = zero(xs) Δ′ = zero(xs)
_, i = findmax(data(xs), r...) _, i = findmax(data(xs), dims = dims)
Δ′[i] = data(Δ) Δ′[i] = data(Δ)
return (nobacksies(:maximum, Δ′),map(_->nothing,r)...) return (nobacksies(:maximum, Δ′),)
end end
end end
@grad function minimum(xs, r...)
minimum(data(xs), r...), function (Δ) @grad function minimum(xs; dims = dims)
minimum(data(xs), dims = dims), function (Δ)
Δ′ = zero(xs) Δ′ = zero(xs)
_, i = findmin(data(xs), r...) _, i = findmin(data(xs), dims = dims)
Δ′[i] = data(Δ) Δ′[i] = data(Δ)
return (nobacksies(:minimum, Δ′),map(_->nothing,r)...) return (nobacksies(:minimum, Δ′),)
end end
end end
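A hedged sketch of the two adjoints above: the cotangent flows only to the position that `findmax`/`findmin` locate, and every other entry gets zero.

```julia
using Flux.Tracker: gradient

gradient(x -> maximum(x), [1.0, 3.0, 2.0])[1]
# == [0.0, 1.0, 0.0]
```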
# BLAS # BLAS
Base.diagm(x::TrackedVector) = track(diagm, x) LinearAlgebra.diagm(x::TrackedVector) = track(diagm, x)
@grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),) @grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),)
for f in :[*, Ac_mul_B, A_mul_Bc, A_mul_Bt, At_mul_B].args x::TrackedMatrix * y::AbstractMatrix = track(*, x, y)
@eval begin x::AbstractMatrix * y::TrackedMatrix = track(*, x, y)
import Base.$f x::TrackedMatrix * y::TrackedMatrix = track(*, x, y)
$f(a::TrackedMatrix, b::TrackedMatrix) = track($f, a, b)
$f(a::TrackedMatrix, b::AbstractMatrix) = track($f, a, b)
$f(a::AbstractMatrix, b::TrackedMatrix) = track($f, a, b)
$f(a::TrackedMatrix, b::TrackedVector) = track($f, a, b) x::TrackedMatrix * y::AbstractVector = track(*, x, y)
$f(a::TrackedMatrix, b::AbstractVector) = track($f, a, b) x::AbstractMatrix * y::TrackedVector = track(*, x, y)
$f(a::AbstractMatrix, b::TrackedVector) = track($f, a, b) x::TrackedMatrix * y::TrackedVector = track(*, x, y)
$f(a::TrackedVector, b::TrackedVector) = track($f, a, b) x::TrackedVector * y::AbstractVector = track(*, x, y)
$f(a::TrackedVector, b::AbstractVector) = track($f, a, b) x::AbstractVector * y::TrackedVector = track(*, x, y)
$f(a::AbstractVector, b::TrackedVector) = track($f, a, b) x::TrackedVector * y::TrackedVector = track(*, x, y)
end
end
@grad a::AbstractMatrix * b::AbstractVecOrMat = @grad a::AbstractMatrix * b::AbstractVecOrMat =
data(a)*data(b), Δ -> (A_mul_Bt(Δ, b), At_mul_B(a, Δ)) data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ)
@grad Ac_mul_B(a, b) = Ac_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ)
@grad A_mul_Bc(a, b) = A_mul_Bc(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)')
@grad At_mul_B(a, b) = At_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ)
@grad A_mul_Bt(a, b) = A_mul_Bt(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)')
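The rule encoded here is the textbook one: for C = A·B with cotangent Δ, ∇A = Δ·Bᵀ and ∇B = Aᵀ·Δ. A hedged check:

```julia
using Flux.Tracker: gradient

W, x = rand(2, 3), rand(3)
gW, gx = gradient((W, x) -> sum(W * x), W, x)
# gW ≈ ones(2) * x'   and   gx ≈ W' * ones(2)
```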
# NNlib # NNlib
@ -352,45 +366,85 @@ end
using ForwardDiff: Dual, partials, value using ForwardDiff: Dual, partials, value
dualify(xs, n) = xs trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
dualify(xs::AbstractArray, ps) = map(x -> Dual(x, ps), xs)
dualify(xs::Real, ps) = Dual(xs, ps)
unbroadcast(x::Tuple, Δ) = unbroadcast(x::AbstractArray, Δ) =
x == size(Δ) ? Δ : size(x) == size(Δ) ? Δ :
reshape(sum(Δ, filter(n -> n > length(x) || x[n] == 1, 1:ndims(Δ))), x) length(x) == length(Δ) ? trim(x, Δ) :
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
unbroadcast(x::Tuple{}, Δ) = sum(Δ) unbroadcast(x::Number, Δ) = sum(Δ)
unbroadcast(x::Base.RefValue{<:Function}, _) = nothing
unbroadcast(x::Base.RefValue{<:Val}, _) = nothing
function getpartial(Δ, x, i) dual(x, p) = x
@inbounds p = getindex(partials(x), i) dual(x::Real, p) = Dual(x, p)
return Δ * p
function partial(f::F, Δ, i, args::Vararg{Any,N}) where {F,N}
dargs = ntuple(j -> dual(args[j], i==j), Val(N))
return Δ * f(dargs...).partials[1]
end end
function ∇broadcast(f, args::Vararg{Any,N}) where N @inline function ∇broadcast(f::F, args::Vararg{Any,N}) where {F,N}
sizes = size.(args) y = broadcast(f, data.(args)...)
dargs = map((x,i) -> dualify(data(x), ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N})) eltype(y) <: Real || return y
out = broadcast(f, dargs...) eltype(y) == Bool && return y
eltype(out) <: Dual || return out function back(Δ)
y = value.(out) Δargs = ntuple(i -> partial.(f, Δ, i, args...), Val(N))
back = function (Δ_) dxs = map(unbroadcast, args, Δargs)
Δ = data(Δ_) return dxs
Δargs = ntuple(i -> getpartial.(Δ, out, i), Val{N})
dxs = map((x, Δ) -> unbroadcast(x, Δ), sizes, Δargs)
nobacksies(:broadcast, dxs)
end end
# So we can return non-tracked arrays # So we can return non-tracked arrays
track(Call(back, tracker.(args)), y) track(Call(back, tracker.(args)), y)
end end
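Putting the pieces together: the duals supply per-argument partials, and `unbroadcast` sums the cotangent back down to each argument's original shape. A hedged sketch:

```julia
using Flux.Tracker: gradient

gx, gy = gradient((x, y) -> sum(x .* y), [1.0, 2.0], 3.0)
# gx == [3.0, 3.0]   (same shape as x)
# gy == 3.0          (a scalar: unbroadcast sums over the expanded dims)
```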
Base.Broadcast._containertype(::Type{<:TrackedReal}) = TrackedArray using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted
Base.Broadcast._containertype(::Type{<:TrackedArray}) = TrackedArray
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{TrackedArray}) = TrackedArray
Base.Broadcast.promote_containertype(::Type{Array}, ::Type{TrackedArray}) = TrackedArray
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{Array}) = TrackedArray
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ct) = TrackedArray
Base.Broadcast.promote_containertype(ct, ::Type{TrackedArray}) = TrackedArray
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A::Ref) = ()
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A) = indices(A)
Base.Broadcast.broadcast_c(f, ::Type{TrackedArray}, A, Bs...) = ∇broadcast(f, A, Bs...) struct TrackedStyle <: BroadcastStyle end
Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray,TrackedReal}}) = TrackedStyle()
Broadcast.BroadcastStyle(::TrackedStyle, ::BroadcastStyle) = TrackedStyle()
# We have to re-build the original broadcast struct to get the appropriate array
# style. We need this primarily to support CuArrays' broadcasting fixes.
broadcast_rebuild(xs) = data(xs)
broadcast_rebuild(bc::Broadcasted) =
broadcasted(bc.f, broadcast_rebuild.(bc.args)...)
preprocess(x) = x
function Base.Broadcast.materialize(bc::Broadcasted{TrackedStyle})
bc1 = Broadcast.flatten(bc)
bc2 = Broadcast.flatten(broadcast_rebuild(bc))
∇broadcast(bc2.f, bc1.args...)
end
using Requires
# https://github.com/FluxML/Flux.jl/issues/353
@init Requires.isprecompiling() || @eval Base.Broadcast begin
function flatten(bc::Broadcasted{Style}) where {Style}
isflat(bc) && return bc
args = cat_nested(bc)
let makeargs = make_makeargs(bc), f = bc.f
newf = @inline function(args::Vararg{Any,N}) where N
f(makeargs(args...)...)
end
return Broadcasted{Style}(newf, args, bc.axes)
end
end
@inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}})
bc = t[1]
let makeargs = make_makeargs(makeargs, tail(t)), f = bc.f
let makeargs = make_makeargs(makeargs, bc.args)
headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
return @inline function(args::Vararg{Any,N}) where N
args1 = makeargs(args...)
a, b = headargs(args1...), tailargs(args1...)
(f(a...), b...)
end
end
end
end
end
@ -26,7 +26,7 @@ function back_(c::Call, Δ)
foreach(back, c.args, data.(Δs)) foreach(back, c.args, data.(Δs))
end end
back_(::Call{Void}, Δ) = nothing back_(::Call{Nothing}, Δ) = nothing
accum!(x, Δ) = x .+ Δ accum!(x, Δ) = x .+ Δ
accum!(x::AbstractArray, Δ) = (x .+= Δ) accum!(x::AbstractArray, Δ) = (x .+= Δ)
@ -47,7 +47,7 @@ function back(x::Tracked, Δ)
return return
end end
back(::Void, _) = return back(::Nothing, _) = return
# Interface methods # Interface methods
@ -70,7 +70,7 @@ struct Params
Params(xs) = new(IdSet(xs)) Params(xs) = new(IdSet(xs))
end end
@forward Params.params Base.start, Base.next, Base.done @forward Params.params Base.iterate, Base.length
function Base.show(io::IO, ps::Params) function Base.show(io::IO, ps::Params)
print(io, "Params([") print(io, "Params([")
@ -79,14 +79,16 @@ function Base.show(io::IO, ps::Params)
end end
struct Grads struct Grads
grads::ObjectIdDict grads::IdDict{Any,Any}
end end
Base.show(io::IO, ps::Grads) = println(io, "Grads(...)") Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
Grads() = Grads(ObjectIdDict()) Grads() = Grads(IdDict())
Grads(ps::Params) = Grads(ObjectIdDict(tracker(p) => init_grad(data(p)) for p in ps)) @forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
Base.getindex(g::Grads, x::Tracked) = g.grads[x] Base.getindex(g::Grads, x::Tracked) = g.grads[x]
function Base.getindex(g::Grads, x) function Base.getindex(g::Grads, x)
@ -94,9 +96,8 @@ function Base.getindex(g::Grads, x)
g[tracker(x)] g[tracker(x)]
end end
@forward Grads.grads Base.setindex!, Base.haskey
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] + Δ : Δ accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
function back_(g::Grads, c::Call, Δ) function back_(g::Grads, c::Call, Δ)
Δs = c.func(Δ) Δs = c.func(Δ)
@ -105,7 +106,7 @@ function back_(g::Grads, c::Call, Δ)
foreach((x, Δ) -> back(g, x, Δ), c.args, Δs) foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
end end
back_(g::Grads, ::Call{Void}, Δ) = nothing back_(g::Grads, ::Call{Nothing}, Δ) = nothing
function back(g::Grads, x::Tracked, Δ) function back(g::Grads, x::Tracked, Δ)
x.isleaf && (accum!(g, x, Δ); return) x.isleaf && (accum!(g, x, Δ); return)
@ -119,7 +120,7 @@ function back(g::Grads, x::Tracked, Δ)
return return
end end
back(::Grads, ::Void, _) = return back(::Grads, ::Nothing, _) = return
function forward(f, ps::Params) function forward(f, ps::Params)
y = f() y = f()
@ -136,7 +137,7 @@ end
function forward(f, args...) function forward(f, args...)
args = param.(args) args = param.(args)
y, back = forward(() -> f(args...), Params(args)) y, back = forward(() -> f(args...), Params(args))
y, Δ -> getindex.(back(Δ), args) y, Δ -> getindex.(Ref(back(Δ)), args)
end end
function losscheck(x) function losscheck(x)
@ -152,3 +153,13 @@ function gradient(f, args...)
end end
derivative(f, x) = gradient(f, x)[1] derivative(f, x) = gradient(f, x)[1]
# Non-nesting versions
function gradient_(f, xs...)
xs = param.(xs)
l = f(xs...)
losscheck(l)
back!(l)
grad.(xs)
end
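For context, a hedged sketch of the `forward` interface these helpers wrap (names as defined in this commit):

```julia
using Flux.Tracker: forward

y, back = forward((a, b) -> a * b, 2.0, 3.0)
# y == 6.0 (tracked); back(1) == (3.0, 2.0), one partial per argument
```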
@ -1,17 +1,17 @@
struct IdSet{T} <: AbstractSet{T} struct IdSet{T} <: AbstractSet{T}
dict::ObjectIdDict dict::IdDict{T,Nothing}
IdSet{T}() where T = new(ObjectIdDict()) IdSet{T}() where T = new(IdDict{T,Nothing}())
end end
Base.eltype{T}(::IdSet{T}) = T Base.eltype(::IdSet{T}) where T = T
IdSet() = IdSet{Any}() IdSet() = IdSet{Any}()
Base.push!{T}(s::IdSet{T}, x::T) = (s.dict[x] = nothing; s) Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
Base.delete!{T}(s::IdSet{T}, x::T) = (delete!(s.dict, x); s) Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
Base.in(x, s::IdSet) = haskey(s.dict, x) Base.in(x, s::IdSet) = haskey(s.dict, x)
(::Type{IdSet{T}}){T}(xs) = push!(IdSet{T}(), xs...) IdSet{T}(xs) where T = push!(IdSet{T}(), xs...)
IdSet(xs) = IdSet{eltype(xs)}(xs) IdSet(xs) = IdSet{eltype(xs)}(xs)
@ -20,6 +20,8 @@ Base.similar(s::IdSet, T::Type) = IdSet{T}()
@forward IdSet.dict Base.length @forward IdSet.dict Base.length
Base.start(s::IdSet) = start(keys(s.dict)) function Base.iterate(v::IdSet, state...)
Base.next(s::IdSet, st) = next(keys(s.dict), st) y = Base.iterate(keys(v.dict), state...)
Base.done(s::IdSet, st) = done(keys(s.dict), st) y === nothing && return nothing
return (y[1], y[2])
end
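A minimal sketch of why an identity-keyed set matters here — parameters must be deduplicated by `===`, not by value (`IdSet` is internal to Tracker; the module path below is an assumption):

```julia
a, b = [1], [1]                # equal but distinct arrays
s = Flux.Tracker.IdSet{Any}()  # assumed path; IdSet is not exported
push!(s, a); push!(s, b)
length(s)                      # 2: both survive, unlike in a value-keyed Set
```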
@ -1,5 +1,5 @@
function ngradient(f, xs::AbstractArray...) function ngradient(f, xs::AbstractArray...)
grads = zeros.(xs) grads = zero.(xs)
for (x, Δ) in zip(xs, grads), i in 1:length(x) for (x, Δ) in zip(xs, grads), i in 1:length(x)
δ = sqrt(eps()) δ = sqrt(eps())
tmp = x[i] tmp = x[i]
@ -30,8 +30,11 @@ Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} = Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
error("Not implemented: convert tracked $S to tracked $T") error("Not implemented: convert tracked $S to tracked $T")
Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y) for op in [:(==), :≈, :<]
Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y) @eval Base.$op(x::TrackedReal, y::Real) = Base.$op(data(x), y)
@eval Base.$op(x::Real, y::TrackedReal) = Base.$op(x, data(y))
@eval Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y))
end
Base.eps(x::TrackedReal) = eps(data(x)) Base.eps(x::TrackedReal) = eps(data(x))
@ -115,3 +118,7 @@ end
function back_(c::Call{typeof(collect)}, Δ) function back_(c::Call{typeof(collect)}, Δ)
foreach(back, c.args[1], data(Δ)) foreach(back, c.args[1], data(Δ))
end end
function back_(g::Grads, c::Call{typeof(collect)}, Δ)
foreach((x, Δ) -> back(g, x, Δ), c.args[1], Δ)
end
@ -7,16 +7,27 @@ mapchildren(f, x) = x
children(x::Tuple) = x children(x::Tuple) = x
mapchildren(f, x::Tuple) = map(f, x) mapchildren(f, x::Tuple) = map(f, x)
function treelike(T, fs = fieldnames(T)) function treelike(m::Module, T, fs = fieldnames(T))
@eval current_module() begin @eval m begin
Flux.children(x::$T) = ($([:(x.$f) for f in fs]...),) Flux.children(x::$T) = ($([:(x.$f) for f in fs]...),)
Flux.mapchildren(f, x::$T) = $T(f.($children(x))...) Flux.mapchildren(f, x::$T) = $T(f.($children(x))...)
end end
end end
function treelike(T, fs = fieldnames(T))
Base.depwarn("`treelike(T)` is deprecated, use `@treelike T`", :treelike)
treelike(Base._current_module(), T, fs)
end
macro treelike(T, fs = nothing)
fs == nothing || isexpr(fs, :tuple) || error("@treelike T (a, b)")
fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
:(treelike(@__MODULE__, $(esc(T)), $(fs...)))
end
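A hedged usage sketch of the new macro (`Affine` is an illustrative layer type, not part of this commit):

```julia
struct Affine
  W
  b
end
(a::Affine)(x) = a.W * x .+ a.b

Flux.@treelike Affine   # mapleaves/params now traverse W and b
```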
isleaf(x) = isempty(children(x)) isleaf(x) = isempty(children(x))
function mapleaves(f, x; cache = ObjectIdDict()) function mapleaves(f, x; cache = IdDict())
haskey(cache, x) && return cache[x] haskey(cache, x) && return cache[x]
cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x) cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x)
end end
@ -43,7 +54,7 @@ function loadparams!(m, xs)
for (p, x) in zip(params(m), xs) for (p, x) in zip(params(m), xs)
size(p) == size(x) || size(p) == size(x) ||
error("Expected param size $(size(p)), got $(size(x))") error("Expected param size $(size(p)), got $(size(x))")
copy!(data(p), data(x)) copyto!(data(p), data(x))
end end
end end
@ -53,7 +64,7 @@ cpu(m) = mapleaves(x -> adapt(Array, x), m)
gpu_adaptor = identity gpu_adaptor = identity
@require CuArrays begin @init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
global gpu_adaptor = CuArrays.cu global gpu_adaptor = CuArrays.cu
end end
@ -1,8 +1,8 @@
# Arrays # Arrays
initn(dims...) = randn(dims...)/100 initn(dims...) = randn(dims...)/100
glorot_uniform(dims...) = (rand(dims...) - 0.5)*sqrt(24.0/(sum(dims))) glorot_uniform(dims...) = (rand(dims...) .- 0.5) .* sqrt(24.0/(sum(dims)))
glorot_normal(dims...) = (randn(dims...)*sqrt(2.0/sum(dims))) glorot_normal(dims...) = randn(dims...) .* sqrt(2.0/sum(dims))
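For reference, both initialisers realise Glorot's target variance of 2/Σdims: a uniform on (-a, a) has variance a²/3, and the uniform version's half-width works out to a = 0.5·√(24/Σdims) = √(6/Σdims), so

```
Var(U(-a, a)) = a²/3 = (6/Σdims)/3 = 2/Σdims
```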
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...)) unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
@ -119,7 +119,7 @@ function throttle(f, timeout; leading=true, trailing=false)
end end
cooldown = false cooldown = false
@schedule try @async try
while (sleep(timeout); later != nothing) while (sleep(timeout); later != nothing)
later() later()
later = nothing later = nothing
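For context, a hedged sketch of `throttle`'s contract with the default `leading=true, trailing=false` (matching the behaviour exercised in the tests later in this diff):

```julia
using Flux: throttle

f = throttle(() -> println("ping"), 1)
f(); f(); f()   # prints once; further calls are dropped for ~1 second
```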
@ -145,7 +145,7 @@ function jacobian(m,x)
y = m(xp) y = m(xp)
k = length(y) k = length(y)
n = length(x) n = length(x)
J = Matrix{eltype(x)}(n,k) J = Matrix{eltype(x)}(undef,n,k)
for i = 1:k for i = 1:k
Flux.back!(y[i]) # Populate gradient accumulator Flux.back!(y[i]) # Populate gradient accumulator
J[:,i] = xp.grad J[:,i] = xp.grad
@ -1,7 +1,7 @@
using Flux, Flux.Tracker, CuArrays, Base.Test using Flux, Flux.Tracker, CuArrays, Test
using Flux: gpu using Flux: gpu
info("Testing Flux/GPU") @info "Testing GPU Support"
@testset "CuArrays" begin @testset "CuArrays" begin
@ -14,6 +14,7 @@ cx = gpu(x)
x = Flux.onehotbatch([1, 2, 3], 1:3) x = Flux.onehotbatch([1, 2, 3], 1:3)
cx = gpu(x) cx = gpu(x)
@test cx isa Flux.OneHotMatrix && cx.data isa CuArray @test cx isa Flux.OneHotMatrix && cx.data isa CuArray
@test (cx .+ 1) isa CuArray
m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax) m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
cm = gpu(m) cm = gpu(m)
@ -25,10 +26,13 @@ x = [1,2,3]
cx = gpu(x) cx = gpu(x)
@test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx) @test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)
# Fails in Pkg.test ffs xs = param(rand(5,5))
# c = gpu(Conv((2,2),3=>4)) ys = Flux.onehotbatch(1:5,1:5)
# l = c(gpu(rand(10,10,3,2))) @test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
# Flux.back!(sum(l))
c = gpu(Conv((2,2),3=>4))
l = c(gpu(rand(10,10,3,2)))
Flux.back!(sum(l))
end end
@ -1,6 +1,6 @@
using Flux, CuArrays, Base.Test using Flux, CuArrays, Test
info("Testing Flux/CUDNN") @info "Testing Flux/CUDNN"
@testset "RNN" begin @testset "RNN" begin
@testset for R in [RNN, GRU, LSTM] @testset for R in [RNN, GRU, LSTM]
@ -1,8 +1,13 @@
using Flux.Data using Flux.Data
using Base.Test using Test
@test cmudict()["CATASTROPHE"] == :[K,AH0,T,AE1,S,T,R,AH0,F,IY0].args @test cmudict()["CATASTROPHE"] == :[K,AH0,T,AE1,S,T,R,AH0,F,IY0].args
@test length(CMUDict.phones()) == 39 @test length(CMUDict.phones()) == 39
@test length(CMUDict.symbols()) == 84 @test length(CMUDict.symbols()) == 84
@test MNIST.images()[1] isa Matrix
@test MNIST.labels() isa Vector{Int64}
@test Data.Sentiment.train() isa Vector{Data.Tree{Any}}
test/layers/conv.jl (new file, 23 lines)
@ -0,0 +1,23 @@
using Flux, Test
using Flux: maxpool, meanpool
@testset "Pooling" begin
x = randn(10, 10, 3, 2)
mp = MaxPool((2, 2))
@test mp(x) == maxpool(x, (2,2))
mp = MeanPool((2, 2))
@test mp(x) == meanpool(x, (2,2))
end
@testset "CNN" begin
r = zeros(28, 28, 1, 5)
m = Chain(
Conv((2, 2), 1=>16, relu),
MaxPool((2,2)),
Conv((2, 2), 16=>8, relu),
MaxPool((2,2)),
x -> reshape(x, :, size(x, 4)),
Dense(288, 10), softmax)
@test size(m(r)) == (10, 5)
end
@ -4,7 +4,7 @@ using Flux: testmode!
x = [1.,2.,3.] x = [1.,2.,3.]
@test x == testmode!(Dropout(0.1))(x) @test x == testmode!(Dropout(0.1))(x)
@test x == Dropout(0)(x) @test x == Dropout(0)(x)
@test zeros(x) == Dropout(1)(x) @test zero(x) == Dropout(1)(x)
x = rand(100) x = rand(100)
m = Dropout(0.9) m = Dropout(0.9)
@ -53,17 +53,17 @@ end
# .1 * 4 + 0 = .4 # .1 * 4 + 0 = .4
@test m.μ ≈ reshape([0.3, 0.4], 2, 1) @test m.μ ≈ reshape([0.3, 0.4], 2, 1)
# julia> .1 .* std(x, 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.] # julia> .1 .* std(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
# 2×1 Array{Float64,2}: # 2×1 Array{Float64,2}:
# 1.14495 # 1.14495
# 1.14495 # 1.14495
@test m.σ ≈ .1 .* std(x.data, 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.] @test m.σ ≈ .1 .* std(x.data, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
testmode!(m) testmode!(m)
@test !m.active @test !m.active
x = m(x).data x = m(x).data
@test x[1] ≈ (1 - 0.3) / 1.1449489742783179 @test x[1] ≈ (1 .- 0.3) / 1.1449489742783179
end end
# with activation function # with activation function
@ -1,4 +1,4 @@
using Base.Test using Test
using Flux: onehotbatch, mse, crossentropy, logitcrossentropy, using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
σ, binarycrossentropy, logitbinarycrossentropy σ, binarycrossentropy, logitbinarycrossentropy
@ -42,8 +42,8 @@ const ϵ = 1e-7
logŷ, y = randn(3), rand(3) logŷ, y = randn(3), rand(3)
@testset "binarycrossentropy" begin @testset "binarycrossentropy" begin
@test binarycrossentropy.(σ.(logŷ), y; ϵ=0) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ)) @test binarycrossentropy.(σ.(logŷ), y; ϵ=0) ≈ -y.*log.(σ.(logŷ)) - (1 .- y).*log.(1 .- σ.(logŷ))
@test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 - y).*log.(1 - σ.(logŷ) .+ eps.(σ.(logŷ))) @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 .- y).*log.(1 .- σ.(logŷ) .+ eps.(σ.(logŷ)))
end end
@testset "logitbinarycrossentropy" begin @testset "logitbinarycrossentropy" begin
test/onehot.jl (new file, 13 lines)
@ -0,0 +1,13 @@
using Flux:onecold
using Test
@testset "onecold" begin
a = [1, 2, 5, 3.]
A = [1 20 5; 2 7 6; 3 9 10; 2 1 14]
labels = ['A', 'B', 'C', 'D']
@test onecold(a) == 3
@test onecold(A) == [3, 1, 4]
@test onecold(a, labels) == 'C'
@test onecold(A, labels) == ['C', 'A', 'D']
end
@ -1,6 +1,6 @@
using Flux.Optimise using Flux.Optimise
using Flux.Tracker using Flux.Tracker
using Test
@testset "Optimise" begin @testset "Optimise" begin
w = randn(10, 10) w = randn(10, 10)
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM] @testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
@ -23,7 +23,7 @@ end
Flux.train!(() -> (sleep(0.1); i += 1; l), Flux.train!(() -> (sleep(0.1); i += 1; l),
Iterators.repeated((), 100), Iterators.repeated((), 100),
()->(), ()->(),
cb = Flux.throttle(() -> (i > 3 && :stop), 1)) cb = Flux.throttle(() -> (i > 3 && stop()), 1))
@test 3 < i < 50 @test 3 < i < 50
end end
@ -1,17 +1,46 @@
using Flux, Base.Test # Pkg.test runs with --check_bounds=1, forcing all bounds checks.
# This is incompatible with CUDAnative (see JuliaGPU/CUDAnative.jl#98)
if Base.JLOptions().check_bounds == 1
file = @__FILE__
run(```
$(Base.julia_cmd())
--color=$(Base.have_color ? "yes" : "no")
--compiled-modules=$(Bool(Base.JLOptions().use_compiled_modules) ? "yes" : "no")
--startup-file=$(Base.JLOptions().startupfile != 2 ? "yes" : "no")
--code-coverage=$(["none", "user", "all"][1+Base.JLOptions().code_coverage])
$(file)
```)
exit()
end
srand(0) using Flux, Test, Random
using Random
Random.seed!(0)
# So we can use the system CuArrays
insert!(LOAD_PATH, 2, "@v#.#")
@testset "Flux" begin @testset "Flux" begin
@info "Testing Basics"
include("utils.jl") include("utils.jl")
include("tracker.jl") include("onehot.jl")
include("layers/normalisation.jl")
include("layers/stateless.jl")
include("optimise.jl") include("optimise.jl")
include("data.jl") include("data.jl")
if Base.find_in_path("CuArrays") ≠ nothing @info "Testing Layers"
include("layers/normalisation.jl")
include("layers/stateless.jl")
include("layers/conv.jl")
@info "Running Gradient Checks"
include("tracker.jl")
if Base.find_package("CuArrays") != nothing
include("cuda/cuda.jl") include("cuda/cuda.jl")
end end
@ -1,22 +1,27 @@
using Flux.Tracker, Base.Test, NNlib using Flux
using Flux.Tracker, Test, NNlib
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
using NNlib: conv using NNlib: conv
using Printf: @sprintf
using LinearAlgebra: Diagonal, dot, LowerTriangular, norm
using Statistics: mean, std
using Random
# using StatsBase
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...) gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
gradtest(f, dims...) = gradtest(f, rand.(dims)...) gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
@testset "Tracker" begin @testset "Tracker" begin
@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2) @test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2) @test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2) @test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2) @test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((w, x) -> w'*x, randn(Float64,10, 2), randn(Float64,10))
@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10)) @test gradtest((w, x) -> w*x', randn(Float64,5,5), randn(Float64,5,5))
@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5)) @test gradtest(x -> sum(x, dims = (2, 3)), (3,4,5))
@test gradtest(x -> sum(x, dims = 1), randn(Float64,2,3))
@test gradtest(x -> sum(x, (2, 3)), (3,4,5)) @test gradtest(x -> sum(x, dims = [1,2]), randn(Float64,2,3))
@test gradtest(x -> prod(x, (2, 3)), (3,4,5)) @test gradtest(x -> sum(x), randn(Float64,2,3))
@test gradtest(x -> prod(x, dims=(2, 3)), (3,4,5))
@test gradtest(x -> prod(x), (3,4,5)) @test gradtest(x -> prod(x), (3,4,5))
@test gradtest(x -> softmax(x).*(1:3), 3) @test gradtest(x -> softmax(x).*(1:3), 3)
@ -28,7 +33,6 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5)) @test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))
@test gradtest(x -> x', rand(5)) @test gradtest(x -> x', rand(5))
function promotiontest(f, A, B, C) function promotiontest(f, A, B, C)
r0 = f(A, B, C) r0 = f(A, B, C)
r1 = f(param(A), B, C) r1 = f(param(A), B, C)
@ -48,8 +52,8 @@ function promotiontest(f, A, B, C)
end end
@testset "concat" begin @testset "concat" begin
cat1(x...) = cat(1, x...) cat1(x...) = cat(x..., dims = 1)
cat2(x...) = cat(2, x...) cat2(x...) = cat(x..., dims = 2)
@testset for vcatf in [vcat, cat1] @testset for vcatf in [vcat, cat1]
@test gradtest(vcatf, rand(5), rand(3)) @test gradtest(vcatf, rand(5), rand(3))
@ -61,6 +65,7 @@ end
@test gradtest(vcatf, rand(5)', rand(2,5)) @test gradtest(vcatf, rand(5)', rand(2,5))
end end
@testset for hcatf in [hcat, cat2] @testset for hcatf in [hcat, cat2]
@test gradtest(hcatf, rand(5), rand(5)) @test gradtest(hcatf, rand(5), rand(5))
@test gradtest(hcatf, rand(5)', rand(5)') @test gradtest(hcatf, rand(5)', rand(5)')
@ -71,17 +76,17 @@ end
@test gradtest(hcatf, rand(5), rand(5,2)) @test gradtest(hcatf, rand(5), rand(5,2))
end end
@testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)] @testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))]
@test gradtest(catf, rand(5)) @test gradtest(catf, rand(5))
@test gradtest(catf, rand(5)') @test gradtest(catf, rand(5)')
@test gradtest(catf, rand(2,5)) @test gradtest(catf, rand(2,5))
@test gradtest(catf, rand(2,5,3)) @test gradtest(catf, rand(2,5,3))
end end
@test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4)) @test gradtest((x...) -> cat(x..., dims = 3), rand(2,5,2), rand(2,5,3), rand(2,5,4))
@testset "cat($dim, ...)" for dim in 3:5 @testset "cat($dim, ...)" for dim in 3:5
catdim = (x...) -> cat(dim, x...) catdim = (x...) -> cat(x..., dims = dim)
@test gradtest(catdim, rand(5), rand(5), rand(5)) @test gradtest(catdim, rand(5), rand(5), rand(5))
@test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5)) @test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5))
@test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3)) @test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3))
@ -89,12 +94,12 @@ end
@test !isa(vcat(rand(2)), TrackedArray) @test !isa(vcat(rand(2)), TrackedArray)
@test !isa(hcat(rand(2)), TrackedArray) @test !isa(hcat(rand(2)), TrackedArray)
@test !isa(cat(1,rand(2)), TrackedArray) @test !isa(cat(rand(2), dims=1), TrackedArray)
@test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1)) @test gradtest((a,b)->cat(a, b, dims = (2,3,5)), rand(2,3), rand(2,4,2,1))
@testset "promotiontest" begin @testset "promotiontest" begin
@testset for fcat in [hcat, vcat, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)] @testset for fcat in [hcat, vcat, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))]
promotiontest(fcat, rand(2), rand(2), rand(2)) promotiontest(fcat, rand(2), rand(2), rand(2))
promotiontest(fcat, rand(2)', rand(2)', rand(2)') promotiontest(fcat, rand(2)', rand(2)', rand(2)')
promotiontest(fcat, rand(2,2), rand(2,2), rand(2,2)) promotiontest(fcat, rand(2,2), rand(2,2), rand(2,2))
@ -105,16 +110,14 @@ end
promotiontest(hcat, rand(2,1), rand(2), rand(2,2)) promotiontest(hcat, rand(2,1), rand(2), rand(2,2))
promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5)) promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5))
promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5)) promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5))
promotiontest((x...) -> cat(3, x...), rand(4,5,3), rand(4,5,1), rand(4,5,2)) promotiontest((x...) -> cat(x..., dims = 3), rand(4,5,3), rand(4,5,1), rand(4,5,2))
end end
end end
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6)) @test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
# TODO unreliable @test gradtest(x -> repeat(x; inner=2), rand(5))
@test gradtest(x -> repmat(x, 5,5), rand(4,5))
@test gradtest(x -> repmat(x, 5), rand(4,5))
@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5)) @test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
@test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3)) @test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))
@ -124,54 +127,59 @@ end
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1)) @test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2)) @test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))
@test gradtest(diagm, rand(3)) @test gradtest(f-> Matrix(Diagonal(f)), rand(3))
@test gradtest(W -> inv(log.(W * W)), (5,5))
@test gradtest((A, B) -> A / B , (1,5), (5,5))
@test gradtest((A, B) -> log.(A * A) / exp.(B * B), (5,5), (5,5))
@test gradtest((A, B) -> log.(A * A) \ exp.(B * B), (5,5), (5,5))
@testset "mean" begin @testset "mean" begin
@test gradtest(mean, rand(2, 3)) @test gradtest(mean, rand(2, 3))
@test gradtest(x -> mean(x, 1), rand(2, 3)) @test gradtest(x -> mean(x, dims=1), rand(2, 3))
@test gradtest(x -> mean(x, 2), rand(2, 3)) @test gradtest(x -> mean(x, dims=2), rand(2, 3))
@test gradtest(x -> mean(x, 3), rand(2, 3, 4)) @test gradtest(x -> mean(x, dims=3), rand(2, 3, 4))
@test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4)) @test gradtest(x -> mean(x, dims=[1, 2]), rand(2, 3, 4))
end end
@testset "maximum" begin @testset "maximum" begin
@test gradtest(maximum, rand(2, 3)) @test gradtest(maximum, rand(2, 3))
@test gradtest(x -> maximum(x, 1), rand(2, 3)) @test gradtest(x -> maximum(x, dims=1), rand(2, 3))
@test gradtest(x -> maximum(x, 2), rand(2, 3)) @test gradtest(x -> maximum(x, dims=2), rand(2, 3))
@test gradtest(x -> maximum(x, 3), rand(2, 3, 4)) @test gradtest(x -> maximum(x, dims=3), rand(2, 3, 4))
@test gradtest(x -> maximum(x, [1, 2]), rand(2, 3, 4)) @test gradtest(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4))
end end
@testset "minimum" begin @testset "minimum" begin
@test gradtest(minimum, rand(2, 3)) @test gradtest(minimum, rand(2, 3))
@test gradtest(x -> minimum(x, 1), rand(2, 3)) @test gradtest(x -> minimum(x, dims=1), rand(2, 3))
@test gradtest(x -> minimum(x, 2), rand(2, 3)) @test gradtest(x -> minimum(x, dims=2), rand(2, 3))
@test gradtest(x -> minimum(x, 3), rand(2, 3, 4)) @test gradtest(x -> minimum(x, dims=3), rand(2, 3, 4))
@test gradtest(x -> minimum(x, [1, 2]), rand(2, 3, 4)) @test gradtest(x -> minimum(x, dims=[1, 2]), rand(2, 3, 4))
end end
@test gradtest(x -> std(x), rand(5,5)) @test gradtest(x -> std(x), rand(5,5))
@test gradtest(x -> std(x, 1), rand(5,5)) @test gradtest(x -> std(x, dims = 1), rand(5,5))
@test gradtest((x, y) -> x .* y, rand(5), rand(5)) @test gradtest((x, y) -> x .* y, rand(5), rand(5))
@test gradtest(dot, rand(5), rand(5)) @test gradtest(dot, rand(5), rand(5))
@test gradtest(vecnorm, rand(5)) @test gradtest(norm, rand(5))
@test gradtest(rand(5)) do x @test gradtest(rand(5)) do x
y = x.^2 y = x.^2
2y + x 2y + x
end end
@test gradtest(conv, rand(10, 3, 2), randn(2, 3, 2)) @test gradtest(conv, rand(10, 3, 2), randn(Float64,2, 3, 2))
@test gradtest(conv, rand(10, 10, 3, 2), randn(2, 2, 3, 2)) @test gradtest(conv, rand(10, 10, 3, 2), randn(Float64,2, 2, 3, 2))
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(2, 2, 2, 3, 2)) @test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64,2, 2, 2, 3, 2))
@test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2)) @test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
@test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2)) @test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))
@ -179,9 +187,30 @@ end
@test gradtest(x -> meanpool(x, (2,2)), rand(10, 10, 3, 2)) @test gradtest(x -> meanpool(x, (2,2)), rand(10, 10, 3, 2))
@test gradtest(x -> meanpool(x, (2,2,2)), rand(5, 5, 5, 3, 2)) @test gradtest(x -> meanpool(x, (2,2,2)), rand(5, 5, 5, 3, 2))
@test (param([1,2,3]) .< 2) == [true, false, false] @testset "equality & order" begin
# TrackedReal
@test param(2)^2 == param(4)
@test param(2)^2 == 4
@test 4 == param(2)^2
@test param(2)^2 == 4.0 @test param(2)^2 ≈ param(4)
@test param(2)^2 ≈ 4
@test 4 ≈ param(2)^2
@test (param([1,2,3]) .< 2) == [true, false, false]
@test (param([1,2,3]) .<= 2) == [true, true, false]
@test (2 .> param([1,2,3])) == [true, false, false]
@test (2 .>= param([1,2,3])) == [true, true, false]
# TrackedArray
@test param([1,2,3]).^2 == param([1,4,9])
@test [1,2,3].^2 == param([1,4,9])
@test param([1,2,3]).^2 == [1,4,9]
@test param([1,2,3]).^2 ≈ param([1,4,9])
@test [1,2,3].^2 ≈ param([1,4,9])
@test param([1,2,3]).^2 ≈ [1,4,9]
end
@testset "reshape" begin @testset "reshape" begin
x = reshape(param(rand(2,2,2)), 4, 2) x = reshape(param(rand(2,2,2)), 4, 2)
@ -211,14 +240,11 @@ end
@testset "Fallbacks" begin @testset "Fallbacks" begin
xs = param([1 2; 3 4]) xs = param([1 2; 3 4])
@test similar(xs) isa Matrix{Float64} @test similar(xs) isa Matrix{Float64}
# Remove this test if we do LowerTriangular properly
L = LowerTriangular(xs)
@test L*L' isa Matrix{TrackedReal{Float64}}
end end
@test @sprintf("%.2f", sum(param([1,2,3]))) == "6.00" @test @sprintf("%.2f", sum(param([1,2,3]))) == "6.00"
@inferred NNlib.conv(param(rand(10,10,3,2)),randn(2,2,3,4)) @inferred NNlib.conv(param(rand(10,10,3,2)),randn(Float64,2,2,3,4))
b = param(rand()) b = param(rand())
Tracker.back!(b) Tracker.back!(b)
@ -231,6 +257,11 @@ Tracker.back!(b)
z = xy[1]*xy[2] z = xy[1]*xy[2]
back!(z) back!(z)
@test grad.((x,y)) == (3, 2) @test grad.((x,y)) == (3, 2)
@test Tracker.gradient(2, 3) do x, y
xy = Tracker.collect([x, y])
xy[1]*xy[2]
end == (3, 2)
end end
# Gradient Hooks # Gradient Hooks
@ -1,9 +1,13 @@
using Flux: throttle, initn, glorot_uniform, glorot_normal, jacobian using Flux
using Flux: throttle, jacobian, initn, glorot_uniform, glorot_normal
using StatsBase: std
using Random
using Test
@testset "Throttle" begin @testset "Throttle" begin
@testset "default behaviour" begin @testset "default behaviour" begin
a = [] a = []
f = throttle(()->push!(a, now()), 1, leading=true, trailing=false) f = throttle(()->push!(a, time()), 1, leading=true, trailing=false)
f() f()
f() f()
f() f()
@ -13,7 +17,7 @@ using Flux: throttle, initn, glorot_uniform, glorot_normal, jacobian
@testset "leading behaviour" begin @testset "leading behaviour" begin
a = [] a = []
f = throttle(()->push!(a, now()), 1, leading=true, trailing=false) f = throttle(()->push!(a, time()), 1, leading=true, trailing=false)
f() f()
@test length(a) == 1 @test length(a) == 1
f() f()
@ -25,7 +29,7 @@ using Flux: throttle, initn, glorot_uniform, glorot_normal, jacobian
@testset "trailing behaviour" begin @testset "trailing behaviour" begin
a = [] a = []
f = throttle(()->push!(a, now()), 1, leading=false, trailing=true) f = throttle(()->push!(a, time()), 1, leading=false, trailing=true)
f() f()
@test length(a) == 0 @test length(a) == 0
f() f()
@ -59,7 +63,7 @@ end
@testset "Initialization" begin @testset "Initialization" begin
# Set random seed so that these tests don't fail randomly # Set random seed so that these tests don't fail randomly
srand(0) Random.seed!(0)
# initn() should yield a kernel with stddev ~= 1e-2 # initn() should yield a kernel with stddev ~= 1e-2
v = initn(10, 10) v = initn(10, 10)
@test std(v) > 0.9*1e-2 @test std(v) > 0.9*1e-2