Merge branch 'master' of https://github.com/FluxML/Flux.jl

commit f3e39a1e55
.gitignore (vendored): 1 change

@@ -5,3 +5,4 @@ docs/build/
 docs/site/
 docs/flux.css
 deps
+Manifest.toml
.travis.yml: 17 changes

@@ -4,11 +4,16 @@ os:
 - linux
 # - osx
 julia:
-- 0.6
+- 0.7
+- 1.0
+- nightly
 # uncomment the following lines to override the default test script
-script:
-- if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
-- julia -e 'Pkg.clone(pwd()); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'
+# script:
+#  - if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
+#  - julia -e 'Pkg.clone(pwd()); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)'
+matrix:
+  allow_failures:
+    - julia: nightly
 after_success:
-- julia -e 'Pkg.add("Documenter")'
-- julia -e 'cd(Pkg.dir("Flux")); include(joinpath("docs", "make.jl"))'
+- julia -e 'using Pkg; Pkg.add("Documenter"); Pkg.add("NNlib")'
+- julia -e 'using Pkg; cd(Pkg.dir("Flux")); include(joinpath("docs", "make.jl"))'
REQUIRE: 5 changes

@@ -1,14 +1,15 @@
-julia 0.6.0
+julia 0.7
 Juno
 MacroTools 0.3.3
 NNlib
 Requires
 Adapt
-GZip
+CodecZlib
 Colors
 ZipFile
 AbstractTrees
 Reexport
+StatsBase

 # AD
 ForwardDiff 0.5.0
@@ -26,6 +26,6 @@ deploydocs(
 repo = "github.com/FluxML/Flux.jl.git",
 target = "build",
 osname = "linux",
-julia = "0.6",
+julia = "1.0",
 deps = nothing,
 make = nothing)
@@ -3,7 +3,7 @@
 It's common to encode categorical variables (like `true`, `false` or `cat`, `dog`) in "one-of-k" or ["one-hot"](https://en.wikipedia.org/wiki/One-hot) form. Flux provides the `onehot` function to make this easy.

 ```
-julia> using Flux: onehot
+julia> using Flux: onehot, onecold

 julia> onehot(:b, [:a, :b, :c])
 3-element Flux.OneHotVector:
@@ -18,22 +18,22 @@ julia> onehot(:c, [:a, :b, :c])
 true
 ```

-The inverse is `argmax` (which can take a general probability distribution, as well as just booleans).
+The inverse is `onecold` (which can take a general probability distribution, as well as just booleans).

 ```julia
-julia> argmax(ans, [:a, :b, :c])
+julia> onecold(ans, [:a, :b, :c])
 :c

-julia> argmax([true, false, false], [:a, :b, :c])
+julia> onecold([true, false, false], [:a, :b, :c])
 :a

-julia> argmax([0.3, 0.2, 0.5], [:a, :b, :c])
+julia> onecold([0.3, 0.2, 0.5], [:a, :b, :c])
 :c
 ```

 ## Batches

-`onehotbatch` creates a batch (matrix) of one-hot vectors, and `argmax` treats matrices as batches.
+`onehotbatch` creates a batch (matrix) of one-hot vectors, and `onecold` treats matrices as batches.

 ```julia
 julia> using Flux: onehotbatch
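Note on this hunk: it renames `argmax` to `onecold` throughout the one-hot docs. As an illustration of the batched behaviour described above (not part of the commit; it assumes the `onehotbatch`/`onecold` API exported by Flux around this release):

```julia
using Flux: onehotbatch, onecold

ys = onehotbatch([:b, :a, :b], [:a, :b, :c])  # 3×3 one-hot matrix, one column per sample
onecold(ys, [:a, :b, :c])                     # [:b, :a, :b]; each column is decoded separately
```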
@@ -1,18 +1,17 @@
 # Flux: The Julia Machine Learning Library

-Flux is a library for machine learning. It comes "batteries-included" with many useful tools built in, but also lets you use the full power of the Julia language where you need it. The whole stack is implemented in clean Julia code (right down to the [GPU kernels](https://github.com/FluxML/CuArrays.jl)) and any part can be tweaked to your liking.
+Flux is a library for machine learning. It comes "batteries-included" with many useful tools built in, but also lets you use the full power of the Julia language where you need it. We follow a few key principles:

-# Installation
+* **Doing the obvious thing**. Flux has relatively few explicit APIs for features like regularisation or embeddings. Instead, writing down the mathematical form will work – and be fast.
+* **You could have written Flux**. All of it, from [LSTMs](https://github.com/FluxML/Flux.jl/blob/ec16a2c77dbf6ab8b92b0eecd11661be7a62feef/src/layers/recurrent.jl#L131) to [GPU kernels](https://github.com/JuliaGPU/CuArrays.jl), is straightforward Julia code. When in doubt, it’s well worth looking at [the source](https://github.com/FluxML/Flux.jl/). If you need something different, you can easily roll your own.
+* **Play nicely with others**. Flux works well with Julia libraries from [data frames](https://github.com/JuliaComputing/JuliaDB.jl) and [images](https://github.com/JuliaImages/Images.jl) to [differential equation solvers](https://github.com/JuliaDiffEq/DifferentialEquations.jl), so you can easily build complex data processing pipelines that integrate Flux models.

-Install [Julia 0.6.0 or later](https://julialang.org/downloads/), if you haven't already.
+## Installation

-```julia
-Pkg.add("Flux")
-# Optional but recommended
-Pkg.update() # Keep your packages up to date
-Pkg.test("Flux") # Check things installed correctly
-```
+Download [Julia 1.0](https://julialang.org/) or later, if you haven't already. You can add Flux from using Julia's package manager, by typing `] add Flux` in the Julia prompt.

-Start with the [basics](models/basics.md). The [model zoo](https://github.com/FluxML/model-zoo/) is also a good starting point for many common kinds of models.
+If you have CUDA you can also run `] add CuArrays` to get GPU support; see [here](gpu.md) for more details.

-See [GPU support](gpu.md) for more details on installing and using Flux with GPUs.
+## Learning Flux
+
+There are several different ways to learn Flux. If you just want to get started writing models, the [model zoo](https://github.com/FluxML/model-zoo/) gives good starting points for many common ones. This documentation provides a reference to all of Flux's APIs, as well as a from-scratch introduction to Flux's take on models and how they work. Once you understand these docs, congratulations, you also understand [Flux's source code](https://github.com/FluxML/Flux.jl), which is intended to be concise, legible and a good reference for more advanced concepts.
@@ -134,7 +134,7 @@ All `Tracked*` objects (`TrackedArray`, `TrackedReal`) are light wrappers around

 ```julia
 julia> x.tracker
-Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Void,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0])
+Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Nothing,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0])
 ```

 The `Tracker` stores the gradient of a given object, which we've seen before.
@@ -10,14 +10,14 @@ using Flux.Tracker
 f(x) = 3x^2 + 2x + 1

 # df/dx = 6x + 2
-f′(x) = Tracker.gradient(f, x)[1]
+df(x) = Tracker.gradient(f, x)[1]

-f′(2) # 14.0 (tracked)
+df(2) # 14.0 (tracked)

 # d²f/dx² = 6
-f′′(x) = Tracker.gradient(f′, x)[1]
+d2f(x) = Tracker.gradient(df, x)[1]

-f′′(2) # 6.0 (tracked)
+d2f(2) # 6.0 (tracked)
 ```

 (We'll learn more about why these numbers show up as `(tracked)` below.)
@@ -172,7 +172,7 @@ using Flux

 layers = [Dense(10, 5, σ), Dense(5, 2), softmax]

-model(x) = foldl((x, m) -> m(x), x, layers)
+model(x) = foldl((x, m) -> m(x), layers, init = x)

 model(rand(10)) # => 2-element vector
 ```
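The `foldl` signature changes because Julia 0.7 moved the starting value into an `init` keyword argument. A minimal, self-contained sketch of the new form (the old three-argument form no longer exists on 1.0):

```julia
layers = [x -> 2 .* x, x -> x .+ 1]

# Julia 0.6: foldl((x, m) -> m(x), input, layers)
# Julia 1.0:
foldl((x, m) -> m(x), layers, init = [1.0, 2.0])  # [3.0, 5.0]
```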
@@ -211,7 +211,7 @@ m(5) # => 26
 Flux provides a set of helpers for custom layers, which you can enable by calling

 ```julia
-Flux.treelike(Affine)
+Flux.@treelike Affine
 ```

 This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).
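`Flux.treelike(Affine)` becomes the macro `Flux.@treelike Affine` in this release. For illustration only, a hedged sketch of what this registers, with `Affine` re-declared here so the snippet is self-contained (the exact return type of `params` varies between Flux versions):

```julia
using Flux

struct Affine
  W
  b
end
Affine(in::Integer, out::Integer) = Affine(param(randn(out, in)), param(randn(out)))
(m::Affine)(x) = m.W * x .+ m.b

Flux.@treelike Affine        # lets params, mapleaves, gpu, etc. see W and b

a = Affine(10, 5)
length(Flux.params(a))       # 2: the weight matrix and the bias
```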
@@ -6,6 +6,8 @@ These core layers form the foundation of almost all neural networks.
 Chain
 Dense
 Conv
+MaxPool
+MeanPool
 ```

 ## Recurrent Layers
@@ -1,7 +1,7 @@
 # Regularisation

 Applying regularisation to model parameters is straightforward. We just need to
-apply an appropriate regulariser, such as `vecnorm`, to each model parameter and
+apply an appropriate regulariser, such as `norm`, to each model parameter and
 add the result to the overall loss.

 For example, say we have a simple regression.
@@ -15,12 +15,12 @@ loss(x, y) = crossentropy(softmax(m(x)), y)
 We can regularise this by taking the (L2) norm of the parameters, `m.W` and `m.b`.

 ```julia
-penalty() = vecnorm(m.W) + vecnorm(m.b)
+penalty() = norm(m.W) + norm(m.b)
 loss(x, y) = crossentropy(softmax(m(x)), y) + penalty()
 ```

 When working with layers, Flux provides the `params` function to grab all
-parameters at once. We can easily penalise everything with `sum(vecnorm, params)`.
+parameters at once. We can easily penalise everything with `sum(norm, params)`.

 ```julia
 julia> params(m)
@@ -28,7 +28,7 @@ julia> params(m)
 param([0.355408 0.533092; … 0.430459 0.171498])
 param([0.0, 0.0, 0.0, 0.0, 0.0])

-julia> sum(vecnorm, params(m))
+julia> sum(norm, params(m))
 26.01749952921026 (tracked)
 ```

@@ -40,7 +40,7 @@ m = Chain(
   Dense(128, 32, relu),
   Dense(32, 10), softmax)

-loss(x, y) = crossentropy(m(x), y) + sum(vecnorm, params(m))
+loss(x, y) = crossentropy(m(x), y) + sum(norm, params(m))

 loss(rand(28^2), rand(10))
 ```
@@ -57,6 +57,6 @@ julia> activations(c, rand(10))
 param([0.0330606, -0.456104])
 param([0.61991, 0.38009])

-julia> sum(vecnorm, ans)
+julia> sum(norm, ans)
 2.639678767773633 (tracked)
 ```
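`vecnorm` was removed in Julia 1.0, which is why the regularisation docs switch to `norm`. On 1.0 `norm` lives in the `LinearAlgebra` standard library and, applied to a matrix, treats the entries as one long vector, giving the same value `vecnorm` used to; a small sketch:

```julia
using LinearAlgebra

W = [3.0 0.0; 0.0 4.0]
norm(W)      # 5.0, the 2-norm over all entries (what vecnorm(W) returned on 0.6)
norm(W, 1)   # 7.0, the sum of absolute values
```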
@@ -1,18 +1,15 @@
-__precompile__()
-
 module Flux

 # Zero Flux Given

-using Juno, Requires, Reexport
+using MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward

-export Chain, Dense, RNN, LSTM, GRU, Conv,
+export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool,
   Dropout, LayerNorm, BatchNorm,
   params, mapleaves, cpu, gpu

 @reexport using NNlib
-using NNlib: @fix

 include("tracker/Tracker.jl")
 using .Tracker
@@ -37,6 +34,6 @@ include("layers/normalise.jl")

 include("data/Data.jl")

-@require CuArrays include("cuda/cuda.jl")
+@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" include("cuda/cuda.jl")

 end # module
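The `@require` line gains an `@init` prefix and a package UUID because Requires.jl changed its interface for Julia 0.7+: conditional code must identify the package by UUID and be registered at module initialisation time. A hedged sketch of the pattern (the module and function names are illustrative only; the UUID is CuArrays', taken from the hunk above):

```julia
module MyPkg

using Requires

function __init__()   # @init is shorthand for placing the block here
  @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
    # evaluated only once CuArrays is loaded in the same session
    has_gpu_backend() = true
  end
end

end # module
```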
@@ -1,6 +1,6 @@
 module CUDA

-using CuArrays
+using ..CuArrays

 CuArrays.cudnn_available() && include("cudnn.jl")

@@ -1,24 +1,27 @@
-using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
+using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
   cudnnDataType, TensorDesc, FilterDesc

+using LinearAlgebra
+
 mutable struct DropoutDesc
-  ptr::Ptr{Void}
+  ptr::Ptr{Nothing}
   states::CuVector{UInt8}
 end

-Base.unsafe_convert(::Type{Ptr{Void}}, dd::DropoutDesc) = dd.ptr
+Base.unsafe_convert(::Type{Ptr{Nothing}}, dd::DropoutDesc) = dd.ptr

 function DropoutDesc(ρ::Real; seed::Integer=0)
   d = [C_NULL]
   s = Csize_t[0]
-  @check ccall((:cudnnCreateDropoutDescriptor,libcudnn), cudnnStatus_t, (Ptr{Ptr{Void}},), d)
-  @check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Void},Ptr{Csize_t}),libcudnn_handle[],s)
+  @check ccall((:cudnnCreateDropoutDescriptor,libcudnn), cudnnStatus_t, (Ptr{Ptr{Nothing}},), d)
+  @check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),libcudnn_handle[],s)
   states = CuArray{UInt8}(s[]) # TODO: can we drop this when ρ=0?
   desc = DropoutDesc(d[], states)
-  @check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},Ptr{Void},Cfloat,Ptr{Void},Csize_t,Culonglong),
+  @check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,Ptr{Nothing},Csize_t,Culonglong),
     desc,libcudnn_handle[],ρ,states,length(states),seed)
-  finalizer(desc, x ->
-    @check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
+  finalizer(desc) do x
+    @check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
+  end
   return desc
 end
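The `finalizer(desc) do x ... end` rewrite reflects Julia 0.7's argument-order change from `finalizer(obj, f)` to `finalizer(f, obj)`, which is what the do-block form expands to. A small CUDA-free sketch:

```julia
mutable struct Handle
  id::Int
end

released = Ref(false)
h = Handle(1)
finalizer(h) do x        # Julia >= 0.7: finalizer(f, obj); the do block supplies f
  released[] = true      # keep finalizer work cheap and allocation-free
end
```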
@@ -43,10 +46,10 @@ const RNN_ALGO_PERSIST_DYNAMIC = 2
 # LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]

 function params(w::CuVector, input, hidden, n = 1)
-  slice(offset, shape) = reshape(w[offset+(1:prod(shape))], shape)
+  slice(offset, shape) = reshape(w[offset.+(1:prod(shape))], shape)
   wx = slice(0, (input, hidden*n))
   wh = slice(length(wx), (hidden, hidden*n))
-  bias = w[length(wx)+length(wh) + (1:hidden*n)]
+  bias = w[length(wx)+length(wh) .+ (1:hidden*n)]
   (wx, wh), bias
 end
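The added dots (`offset .+ (1:prod(shape))`) are required because Julia 1.0 no longer broadcasts `+` between a scalar and a range implicitly. Sketch:

```julia
w = collect(1.0:10.0)
offset = 3
w[offset .+ (1:4)]    # elements 4 to 7; `offset + (1:4)` is a MethodError on Julia 1.0
```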
@@ -57,14 +60,14 @@ mutable struct RNNDesc{T}
   params::CuVector{T}
   weights::NTuple{2,CuMatrix{T}}
   bias::CuVector{T}
-  ptr::Ptr{Void}
+  ptr::Ptr{Nothing}
 end

-Base.unsafe_convert(::Type{Ptr{Void}}, d::RNNDesc) = d.ptr
+Base.unsafe_convert(::Type{Ptr{Nothing}}, d::RNNDesc) = d.ptr

 function rnnParamSize(T, r, input)
   size = Csize_t[0]
-  @check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Ptr{Void},Ptr{Csize_t},Cint),
+  @check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Ptr{Nothing},Ptr{Csize_t},Cint),
     libcudnn_handle[], r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T))
   return Int(size[])÷sizeof(T)
 end
@@ -74,26 +77,27 @@ ngates(r::RNNDesc) = ngates(r.mode)

 function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
   d = [C_NULL]
-  @check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Void}},),d)
+  @check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Nothing}},),d)

   dropoutDesc = DropoutDesc(0)
   inputMode = LINEAR_INPUT
   direction = UNIDIRECTIONAL
   algo = RNN_ALGO_STANDARD
-  @check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Cint,Ptr{Void},Cint,Cint,Cint,Cint,Cint),
+  @check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Cint,Ptr{Nothing},Cint,Cint,Cint,Cint,Cint),
     libcudnn_handle[],d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))

   w = cuzeros(T, rnnParamSize(T, d[], input))
   # TODO: avoid reserve allocation here
   rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
-  finalizer(rd, x ->
-    @check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
+  finalizer(rd) do x
+    @check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
+  end
   return rd
 end

 function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc)
   size = Csize_t[0]
-  @check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Ptr{Ptr{Void}},Ptr{Csize_t}),
+  @check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Ptr{Ptr{Nothing}},Ptr{Csize_t}),
     libcudnn_handle[], r, seqlen, xdesc, size)
   return Int(size[])
 end
@@ -110,7 +114,7 @@ getworkspace(r::RNNDesc, seqlen, xdesc) =

 function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
   size = Csize_t[0]
-  @check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Void}, Ptr{Void}, Cint, Ptr{Ptr{Void}}, Ptr{Csize_t}),
+  @check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Nothing}, Ptr{Nothing}, Cint, Ptr{Ptr{Nothing}}, Ptr{Csize_t}),
     libcudnn_handle[], r, seqlen, xdesc, size)
   return Int(size[])
 end
@@ -119,19 +123,19 @@ function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd
     workspace, reserve=nothing) where T
   if reserve == nothing
     @check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
-      (Ptr{Void}, Ptr{Void}, Cint,
-      Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
-      Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T},
-      Ptr{Void}, Ptr{T},
-      Ptr{Void}, Csize_t),
+      (Ptr{Nothing}, Ptr{Nothing}, Cint,
+      Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
+      Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
+      Ptr{Nothing}, Ptr{T},
+      Ptr{Nothing}, Csize_t),
       libcudnn_handle[], rnn, seqlen,
       xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
       workspace, length(workspace))
   else
     @check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t,
-      (Ptr{Void}, Ptr{Void}, Cint,
-      Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
-      Ptr{Void}, Csize_t, Ptr{Void}, Csize_t),
+      (Ptr{Nothing}, Ptr{Nothing}, Cint,
+      Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
+      Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
       libcudnn_handle[], rnn, seqlen,
       xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
       workspace, length(workspace), reserve, length(reserve))
@@ -140,7 +144,7 @@ end

 xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]

-hDesc(h::Void) = C_NULL, C_NULL
+hDesc(h::Nothing) = C_NULL, C_NULL
 hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
 function hDesc(h::CuArray)
   TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
@@ -187,18 +191,18 @@ forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T
 function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
     wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
   @check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
-    (Ptr{Void}, Ptr{Void}, Cint,
-    Ptr{Ptr{Void}}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T},
-    Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void},
-    Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
-    Ptr{Void}, Csize_t, Ptr{Void}, Csize_t),
+    (Ptr{Nothing}, Ptr{Nothing}, Cint,
+    Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
+    Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing},
+    Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
+    Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
     libcudnn_handle[], rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
     wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
 end

 function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
   # Same as above, any more efficient way?
-  dy = dy_ isa Integer ? zeros(y) : dy_
+  dy = dy_ isa Integer ? zero(y) : dy_
   yd = xDesc(y)
   dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
   dh = similar(h)
@@ -217,19 +221,19 @@ backwardData(rnn, y, dy, dho, hx, reserve) =
 function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw,
     workspace, reserve) where T
   @check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t,
-    (Ptr{Void}, Ptr{Void}, Cint, # handle, rnnDesc, seqLength
-    Ptr{Ptr{Void}}, Ptr{T}, #x
-    Ptr{Void}, Ptr{T}, #hx
-    Ptr{Ptr{Void}}, Ptr{T}, #y
-    Ptr{Void}, Csize_t, #ws
-    Ptr{Void}, Ptr{T}, #dw
-    Ptr{Void}, Csize_t), #rs
+    (Ptr{Nothing}, Ptr{Nothing}, Cint, # handle, rnnDesc, seqLength
+    Ptr{Ptr{Nothing}}, Ptr{T}, #x
+    Ptr{Nothing}, Ptr{T}, #hx
+    Ptr{Ptr{Nothing}}, Ptr{T}, #y
+    Ptr{Nothing}, Csize_t, #ws
+    Ptr{Nothing}, Ptr{T}, #dw
+    Ptr{Nothing}, Csize_t), #rs
     libcudnn_handle[], rnn, seqlen, xd, x, hd, h, yd, y,
     workspace, length(workspace), dwd, dw, reserve, length(reserve))
 end

 function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
-  dw = zeros(rnn.params)
+  dw = zero(rnn.params)
   cudnnRNNBackwardWeights(rnn, 1,
     xDesc(x), x, hDesc(h)..., xDesc(y), y,
     FilterDesc(T, (1, 1, length(dw))), dw,
@@ -241,17 +245,17 @@ end

 import ..Flux: Flux, relu
 import ..Tracker: TrackedArray
-using CUDAnative
-using CuArrays: @cuindex, cudims
+using .CuArrays.CUDAnative
+using .CuArrays: @cuindex, cudims

-function copy_transpose!(dst::CuArray, src::CuArray)
+function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
   function kernel(dst, src)
     I = @cuindex dst
     dst[I...] = src[reverse(I)...]
     return
   end
   blk, thr = cudims(dst)
-  @cuda (blk, thr) kernel(dst, src)
+  @cuda blocks=blk threads=thr kernel(dst, src)
   return dst
 end
@@ -324,7 +328,7 @@ end
     h_ = hBatch(x, data(h))
     dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
     (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
-    nobacksies(:RNN, (dx, unbroadcast(size(h), dh), dWi.', dWh.', db))
+    nobacksies(:RNN, (dx, unbroadcast(size(h), dh), transpose(dWi), transpose(dWh), db))
   end
 end

@@ -339,6 +343,6 @@ end
     (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
     nobacksies(:RNN,
       (dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc),
-      dWi.', dWh.', db))
+      transpose(dWi), transpose(dWh), db))
   end
 end
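The postfix transpose operator `.'` was removed in Julia 1.0, hence `dWi.'` becomes `transpose(dWi)`. Sketch:

```julia
A = [1 2; 3 4]
transpose(A)    # lazy transpose wrapper, the replacement for A.'
permutedims(A)  # materialised copy, if a plain Array is needed
```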
@@ -11,6 +11,8 @@ function __init__()
 end

 include("mnist.jl")
+export MNIST

 include("cmudict.jl")
 using .CMUDict
@@ -14,7 +14,7 @@ function load()
       return
     end
   end
-  info("Downloading CMUDict dataset")
+  @info "Downloading CMUDict dataset"
   mkpath(deps("cmudict"))
   for x in suffixes
     download("$cache_prefix/http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x",
@@ -24,25 +24,25 @@ end

 function phones()
   load()
-  Symbol.(first.(split.(split(readstring(deps("cmudict", "cmudict.phones")),
-    "\n", keep = false), "\t")))
+  Symbol.(first.(split.(split(read(deps("cmudict", "cmudict.phones"),String),
+    "\n", keepempty = false), "\t")))
 end

 function symbols()
   load()
-  Symbol.(split(readstring(deps("cmudict", "cmudict.symbols")),
-    "\n", keep = false))
+  Symbol.(split(read(deps("cmudict", "cmudict.symbols"),String),
+    "\n", keepempty = false))
 end

 function rawdict()
   load()
   Dict(String(xs[1]) => Symbol.(xs[2:end]) for xs in
-    filter(!isempty, split.(split(readstring(deps("cmudict", "cmudict")), "\n"))))
+    filter(!isempty, split.(split(read(deps("cmudict", "cmudict"),String), "\n"))))
 end

-validword(s) = ismatch(r"^[\w\-\.]+$", s)
+validword(s) = isascii(s) && occursin(r"^[\w\-\.]+$", s)

-cmudict() = filter((s, ps) -> validword(s), rawdict())
+cmudict() = filter(p -> validword(p.first), rawdict())

 alphabet() = ['A':'Z'..., '0':'9'..., '_', '-', '.']
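Two 0.7 renames appear in this file: `readstring(f)` became `read(f, String)`, and `split`'s `keep` keyword became `keepempty`. A self-contained sketch:

```julia
path, io = mktemp(); write(io, "AA\nAE\n\n"); close(io)

text = read(path, String)               # was readstring(path)
split(text, "\n", keepempty = false)    # ["AA", "AE"]; was keep = false
```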
@@ -1,11 +1,17 @@
 module MNIST

-using GZip, Colors
+using CodecZlib, Colors

 const Gray = Colors.Gray{Colors.N0f8}

 const dir = joinpath(@__DIR__, "../../deps/mnist")

+function gzopen(f, file)
+  open(file) do io
+    f(GzipDecompressorStream(io))
+  end
+end
+
 function load()
   mkpath(dir)
   cd(dir) do
@@ -14,10 +20,10 @@ function load()
         "t10k-images-idx3-ubyte",
         "t10k-labels-idx1-ubyte"]
       isfile(file) && continue
-      info("Downloading MNIST dataset")
+      @info "Downloading MNIST dataset"
       download("https://cache.julialang.org/http://yann.lecun.com/exdb/mnist/$file.gz", "$file.gz")
       open(file, "w") do io
-        write(io, GZip.open(read, "$file.gz"))
+        write(io, gzopen(read, "$file.gz"))
       end
     end
   end
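The MNIST loader moves from GZip.jl to CodecZlib; the new `gzopen` helper simply wraps a file handle in a `GzipDecompressorStream`. A usage sketch with a hypothetical file name:

```julia
using CodecZlib

bytes = open("train-images-idx3-ubyte.gz") do io   # any gzipped file
  read(GzipDecompressorStream(io))
end
```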
@@ -49,7 +55,7 @@ function labelheader(io::IO)
 end

 function rawimage(io::IO)
-  img = Array{Gray}(NCOLS, NROWS)
+  img = Array{Gray}(undef, NCOLS, NROWS)
   for i in 1:NCOLS, j in 1:NROWS
     img[i, j] = reinterpret(Colors.N0f8, read(io, UInt8))
   end
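`Array{T}(dims...)` without an initialiser was removed in Julia 1.0; uninitialised arrays now need the explicit `undef` flag. Sketch:

```julia
img = Array{Float64}(undef, 28, 28)   # was Array{Float64}(28, 28) on 0.6
fill!(img, 0.0)
```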
@@ -4,8 +4,8 @@ using ZipFile
 using ..Data: deps

 function load()
-  isfile(deps("sentiment.zip")) || return
-  info("Downloading sentiment treebank dataset")
+  isfile(deps("sentiment.zip")) && return
+  @info "Downloading sentiment treebank dataset"
   download("https://cache.julialang.org/https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip",
     deps("sentiment.zip"))
 end
@@ -14,7 +14,7 @@ getfile(r, name) = r.files[findfirst(x -> x.name == name, r.files)]

 function getfile(name)
   r = ZipFile.Reader(deps("sentiment.zip"))
-  text = readstring(getfile(r, "trees/$name"))
+  text = read(getfile(r, "trees/$name"), String)
   close(r)
   return text
 end
@@ -26,15 +26,16 @@ totree_(n, a, b) = Tree{Any}((parse(Int, n), nothing), totree(a), totree(b))
 totree(t::Expr) = totree_(t.args...)

 function parsetree(s)
-  s = replace(s, r"\$", s -> "\\\$")
-  s = replace(s, r"[^\s\(\)]+", s -> "\"$s\"")
-  s = replace(s, " ", ", ")
-  return totree(parse(s))
+  s = replace(s, "\\" => "")
+  s = replace(s, "\$" => "\\\$")
+  s = replace(s, r"[^ \n\(\)]+" => s -> "\"$s\"")
+  s = replace(s, " " => ", ")
+  return totree(Meta.parse(s))
 end

 function gettrees(name)
   load()
-  ss = split(getfile("$name.txt"), '\n', keep = false)
+  ss = split(getfile("$name.txt"), '\n', keepempty = false)
   return parsetree.(ss)
 end
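On 0.7+, `replace` takes `pattern => replacement` pairs and string parsing moved to `Meta.parse`, which is what the rewritten `parsetree` uses. A rough sketch of the quoting step on a toy tree string (illustrative input only):

```julia
s = "(3 (2 word) (1 other))"
s = replace(s, r"[^ \n\(\)]+" => w -> "\"$w\"")   # quote every token
s = replace(s, " " => ", ")
Meta.parse(s)                                     # an Expr, ready for totree
```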
@@ -16,19 +16,19 @@ m(x) == m[2](m[1](x))
 `Chain` also supports indexing and slicing, e.g. `m[2]` or `m[1:end-1]`.
 `m[1:3](x)` will calculate the output of the first three layers.
 """
-type Chain
+struct Chain
   layers::Vector{Any}
   Chain(xs...) = new([xs...])
 end

-@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
-@forward Chain.layers Base.start, Base.next, Base.done
+@forward Chain.layers Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!
+@forward Chain.layers Base.iterate

 children(c::Chain) = c.layers
 mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
 adapt(T, c::Chain) = Chain(map(x -> adapt(T, x), c.layers)...)

-(c::Chain)(x) = foldl((x, m) -> m(x), x, c.layers)
+(c::Chain)(x) = foldl((x, m) -> m(x), c.layers; init = x)

 Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
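`Base.start/next/done` were folded into the single `Base.iterate` protocol in 0.7; forwarding `Base.iterate` to `Chain.layers` is what keeps iteration over a `Chain` working. A hedged sketch assuming this release's API:

```julia
using Flux

m = Chain(Dense(10, 5, relu), Dense(5, 2), softmax)
for layer in m            # iteration is forwarded to m.layers
  println(typeof(layer))
end
```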
@@ -38,10 +38,7 @@ function Base.show(io::IO, c::Chain)
   print(io, ")")
 end

-# Seem to need this for `accumulate`; try removing on 0.7
-Base.rcum_promote_type(op, ::Type, ::Type{Any}) = Any
-
-activations(c::Chain, x) = accumulate((x, m) -> m(x), x, c.layers)
+activations(c::Chain, x) = accumulate((x, m) -> m(x), c.layers, init = x)

 """
     Dense(in::Integer, out::Integer, σ = identity)
@@ -76,11 +73,11 @@ function Dense(in::Integer, out::Integer, σ = identity;
   return Dense(param(initW(out, in)), param(initb(out)), σ)
 end

-treelike(Dense)
+@treelike Dense

 function (a::Dense)(x)
   W, b, σ = a.W, a.b, a.σ
-  @fix σ.(W*x .+ b)
+  σ.(W*x .+ b)
 end

 function Base.show(io::IO, l::Dense)
@@ -107,7 +104,7 @@ end
 Diagonal(in::Integer; initα = ones, initβ = zeros) =
   Diagonal(param(initα(in)), param(initβ(in)))

-treelike(Diagonal)
+@treelike Diagonal

 function (a::Diagonal)(x)
   α, β = a.α, a.β
@@ -1,6 +1,6 @@
 using NNlib: conv

-@generated sub2(::Type{Val{N}}) where N = :(Val{$(N-2)})
+@generated sub2(::Val{N}) where N = :(Val($(N-2)))

 expand(N, i::Tuple) = i
 expand(N, i::Integer) = ntuple(_ -> i, N)
@@ -28,14 +28,14 @@ end

 Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
     stride = 1, pad = 0, dilation = 1) where {T,N} =
-  Conv(σ, w, b, expand.(sub2(Val{N}), (stride, pad, dilation))...)
+  Conv(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)

 Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
     stride = 1, pad = 0, dilation = 1) where N =
   Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
     stride = stride, pad = pad, dilation = dilation)

-Flux.treelike(Conv)
+@treelike Conv

 function (c::Conv)(x)
   # TODO: breaks gpu broadcast :(
|
|||||||
l.σ == identity || print(io, ", ", l.σ)
|
l.σ == identity || print(io, ", ", l.σ)
|
||||||
print(io, ")")
|
print(io, ")")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
MaxPool(k)
|
||||||
|
|
||||||
|
Max pooling layer. `k` stands for the size of the window for each dimension of the input.
|
||||||
|
|
||||||
|
Takes the keyword arguments `pad` and `stride`.
|
||||||
|
"""
|
||||||
|
struct MaxPool{N}
|
||||||
|
k::NTuple{N,Int}
|
||||||
|
pad::NTuple{N,Int}
|
||||||
|
stride::NTuple{N,Int}
|
||||||
|
end
|
||||||
|
|
||||||
|
MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
|
||||||
|
MaxPool(k, expand(Val(N), pad), expand(Val(N), stride))
|
||||||
|
|
||||||
|
(m::MaxPool)(x) = maxpool(x, m.k; pad = m.pad, stride = m.stride)
|
||||||
|
|
||||||
|
function Base.show(io::IO, m::MaxPool)
|
||||||
|
print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
|
||||||
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
MeanPool(k)
|
||||||
|
|
||||||
|
Mean pooling layer. `k` stands for the size of the window for each dimension of the input.
|
||||||
|
|
||||||
|
Takes the keyword arguments `pad` and `stride`.
|
||||||
|
"""
|
||||||
|
struct MeanPool{N}
|
||||||
|
k::NTuple{N,Int}
|
||||||
|
pad::NTuple{N,Int}
|
||||||
|
stride::NTuple{N,Int}
|
||||||
|
end
|
||||||
|
|
||||||
|
MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
|
||||||
|
MeanPool(k, expand(Val(N), pad), expand(Val(N), stride))
|
||||||
|
|
||||||
|
(m::MeanPool)(x) = meanpool(x, m.k; pad = m.pad, stride = m.stride)
|
||||||
|
|
||||||
|
function Base.show(io::IO, m::MeanPool)
|
||||||
|
print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
|
||||||
|
end
|
||||||
|
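The new `MaxPool`/`MeanPool` layers are thin wrappers over NNlib's `maxpool`/`meanpool`. A usage sketch, not part of the commit:

```julia
using Flux

x = rand(Float32, 28, 28, 1, 16)   # width × height × channels × batch
p = MaxPool((2, 2))                # 2×2 window; stride defaults to the window size
size(p(x))                         # (14, 14, 1, 16)
```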
@@ -58,7 +58,7 @@ end
 LayerNorm(h::Integer) =
   LayerNorm(Diagonal(h))

-treelike(LayerNorm)
+@treelike LayerNorm

 (a::LayerNorm)(x) = a.diag(normalise(x))

@@ -108,7 +108,7 @@ mutable struct BatchNorm{F,V,W,N}
 end

 BatchNorm(chs::Integer, λ = identity;
-    initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1) =
+    initβ = (i) -> zeros(i), initγ = (i) -> ones(i), ϵ = 1e-8, momentum = .1) =
   BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
     zeros(chs), ones(chs), ϵ, momentum, true)

@@ -130,13 +130,13 @@ function (BN::BatchNorm)(x)

     ϵ = data(convert(T, BN.ϵ))
     axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
-    μ = mean(x, axes)
-    σ = sqrt.(mean((x .- μ).^2, axes) .+ ϵ)
+    μ = mean(x, dims = axes)
+    σ = sqrt.(mean((x .- μ).^2, dims = axes) .+ ϵ)

     # update moving mean/std
     mtm = data(convert(T, BN.momentum))
-    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* squeeze(data(μ), (axes...))
-    BN.σ = (1 - mtm) .* BN.σ .+ mtm .* squeeze(data(σ), (axes...)) .* m ./ (m - 1)
+    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* dropdims(data(μ), dims = (axes...,))
+    BN.σ = (1 - mtm) .* BN.σ .+ mtm .* dropdims(data(σ), dims = (axes...,)) .* m ./ (m - 1)
   end

   let λ = BN.λ
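On 0.7+ reductions take a `dims` keyword and `squeeze` became `dropdims`, which is what the batch statistics now use. Sketch:

```julia
using Statistics

x = [1.0 3.0 5.0; 2.0 4.0 6.0]
μ = mean(x, dims = 2)       # 2×1 matrix [3.0; 4.0]; was mean(x, 2)
dropdims(μ, dims = 2)       # 2-element vector [3.0, 4.0]; was squeeze(μ, 2)
```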
@@ -1,4 +1,4 @@
-gate(h, n) = (1:h) + h*(n-1)
+gate(h, n) = (1:h) .+ h*(n-1)
 gate(x::AbstractVector, h, n) = x[gate(h,n)]
 gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]

@@ -38,7 +38,7 @@ function (m::Recur)(xs...)
   return y
 end

-treelike(Recur, (:cell, :init))
+@treelike Recur cell, init

 Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")

@@ -84,7 +84,7 @@ end
 RNNCell(in::Integer, out::Integer, σ = tanh;
     init = glorot_uniform) =
   RNNCell(σ, param(init(out, in)), param(init(out, out)),
-    param(zeros(out)), param(initn(out)))
+    param(zeros(out)), param(init(out)))

 function (m::RNNCell)(h, x)
   σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
@@ -94,7 +94,7 @@ end

 hidden(m::RNNCell) = m.h

-treelike(RNNCell)
+@treelike RNNCell

 function Base.show(io::IO, l::RNNCell)
   print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1))
@@ -123,13 +123,12 @@ end
 function LSTMCell(in::Integer, out::Integer;
     init = glorot_uniform)
   cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)),
-    param(initn(out)), param(initn(out)))
-  cell.b.data[gate(out, 2)] = 1
+    param(init(out)), param(init(out)))
+  cell.b.data[gate(out, 2)] .= 1
   return cell
 end

-function (m::LSTMCell)(h_, x)
-  h, c = h_ # TODO: nicer syntax on 0.7
+function (m::LSTMCell)((h, c), x)
   b, o = m.b, size(h, 1)
   g = m.Wi*x .+ m.Wh*h .+ b
   input = σ.(gate(g, o, 1))
@@ -143,7 +142,7 @@ end

 hidden(m::LSTMCell) = (m.h, m.c)

-treelike(LSTMCell)
+@treelike LSTMCell

 Base.show(io::IO, l::LSTMCell) =
   print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷4, ")")
@@ -170,7 +169,7 @@ end

 GRUCell(in, out; init = glorot_uniform) =
   GRUCell(param(init(out*3, in)), param(init(out*3, out)),
-    param(zeros(out*3)), param(initn(out)))
+    param(zeros(out*3)), param(init(out)))

 function (m::GRUCell)(h, x)
   b, o = m.b, size(h, 1)
@@ -178,13 +177,13 @@ function (m::GRUCell)(h, x)
   r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
   z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
   h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))
-  h′ = (1.-z).*h̃ .+ z.*h
+  h′ = (1 .- z).*h̃ .+ z.*h
   return h′, h′
 end

 hidden(m::GRUCell) = m.h

-treelike(GRUCell)
+@treelike GRUCell

 Base.show(io::IO, l::GRUCell) =
   print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")
@ -5,7 +5,7 @@ using NNlib: logsoftmax, logσ
|
|||||||
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
|
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
|
||||||
|
|
||||||
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
|
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
|
||||||
@fix -sum(y .* log.(ŷ) .* weight) / size(y, 2)
|
-sum(y .* log.(ŷ) .* weight) / size(y, 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
@deprecate logloss(x, y) crossentropy(x, y)
|
@deprecate logloss(x, y) crossentropy(x, y)
|
||||||
@ -47,7 +47,7 @@ logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
|
|||||||
Normalise each column of `x` to mean 0 and standard deviation 1.
|
Normalise each column of `x` to mean 0 and standard deviation 1.
|
||||||
"""
|
"""
|
||||||
function normalise(x::AbstractVecOrMat)
|
function normalise(x::AbstractVecOrMat)
|
||||||
μ′ = mean(x, 1)
|
μ′ = mean(x, dims = 1)
|
||||||
σ′ = std(x, 1, mean = μ′)
|
σ′ = std(x, dims = 1, mean = μ′)
|
||||||
return (x .- μ′) ./ σ′
|
return (x .- μ′) ./ σ′
|
||||||
end
|
end
|
||||||
|
@ -32,20 +32,21 @@ import Adapt.adapt
|
|||||||
|
|
||||||
adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
||||||
|
|
||||||
@require CuArrays begin
|
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
|
||||||
import CuArrays: CuArray, cudaconvert
|
import .CuArrays: CuArray, cudaconvert
|
||||||
Base.Broadcast._containertype(::Type{<:OneHotMatrix{<:CuArray}}) = CuArray
|
import Base.Broadcast: BroadcastStyle, ArrayStyle
|
||||||
|
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
|
||||||
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
||||||
end
|
end
|
||||||
|
|
||||||
function onehot(l, labels)
|
function onehot(l, labels)
|
||||||
i = findfirst(labels, l)
|
i = something(findfirst(isequal(l), labels), 0)
|
||||||
i > 0 || error("Value $l is not in labels")
|
i > 0 || error("Value $l is not in labels")
|
||||||
OneHotVector(i, length(labels))
|
OneHotVector(i, length(labels))
|
||||||
end
|
end
|
||||||
|
|
||||||
function onehot(l, labels, unk)
|
function onehot(l, labels, unk)
|
||||||
i = findfirst(labels, l)
|
i = something(findfirst(isequal(l), labels), 0)
|
||||||
i > 0 || return onehot(unk, labels)
|
i > 0 || return onehot(unk, labels)
|
||||||
OneHotVector(i, length(labels))
|
OneHotVector(i, length(labels))
|
||||||
end
|
end
|
||||||
@ -53,11 +54,15 @@ end
|
|||||||
onehotbatch(ls, labels, unk...) =
|
onehotbatch(ls, labels, unk...) =
|
||||||
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
|
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
|
||||||
|
|
||||||
argmax(y::AbstractVector, labels = 1:length(y)) =
|
onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]
|
||||||
labels[findfirst(y, maximum(y))]
|
|
||||||
|
|
||||||
argmax(y::AbstractMatrix, l...) =
|
onecold(y::AbstractMatrix, labels...) =
|
||||||
squeeze(mapslices(y -> argmax(y, l...), y, 1), 1)
|
dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1)
|
||||||
|
|
||||||
|
function argmax(xs...)
|
||||||
|
Base.depwarn("`argmax(...) is deprecated, use `onecold(...)` instead.", :argmax)
|
||||||
|
return onecold(xs...)
|
||||||
|
end
|
||||||
|
|
||||||
# Ambiguity hack
|
# Ambiguity hack
|
||||||
|
|
||||||
|
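`findfirst(collection, value)` became `findfirst(predicate, collection)` in 0.7 and now returns `nothing` on a miss; `something(..., 0)` restores the old `0` sentinel that the surrounding checks expect. Sketch:

```julia
labels = [:a, :b, :c]
findfirst(isequal(:b), labels)                  # 2
findfirst(isequal(:d), labels)                  # nothing
something(findfirst(isequal(:d), labels), 0)    # 0
```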
@@ -2,14 +2,14 @@ module Optimise

 export train!,
   SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
-  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
+  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM, stop, StopException

 struct Param{T}
   x::T
   Δ::T
 end

-Base.convert(::Type{Param}, x::AbstractArray) = Param(x, zeros(x))
+Param(x::AbstractArray) = Param(x, zero(x))

 include("optimisers.jl")
 include("interface.jl")
@@ -17,6 +17,7 @@ include("train.jl")

 using Flux.Tracker: TrackedArray

-Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)
+Param(x::TrackedArray) = Param(x.data, x.grad)
+# Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)

 end
@ -14,7 +14,7 @@ function descentweightdecay(p::Param, η::Real, γ::Real)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function momentum(p::Param, ρ, η)
-  v = zeros(p.x)
+  v = zero(p.x)
  function ()
    @. v = ρ * v - η * p.Δ
    @. p.Δ = -v
@@ -23,7 +23,7 @@ end

# Ref. https://arxiv.org/pdf/1212.0901.pdf
function nesterov(p::Param, ρ, η)
-  v = zeros(p.x)
+  v = zero(p.x)
  function ()
    d = @. ρ^2 * v - (1+ρ) * η * p.Δ
    @. v = ρ*v - η*p.Δ
@@ -32,7 +32,7 @@ function nesterov(p::Param, ρ, η)
  end
end

function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
-  acc = zeros(p.x)
+  acc = zero(p.x)
  function ()
    @. acc = ρ * acc + (1 - ρ) * p.Δ^2
    @. p.Δ *= η / √(acc + ϵ)
@@ -40,7 +40,7 @@ function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
  end
end

function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
-  acc = zeros(p.x) .+ ϵ
+  acc = zero(p.x) .+ ϵ
  function ()
    @. acc += p.Δ^2
    @. p.Δ *= η / √(acc + ϵ)
@@ -48,8 +48,8 @@ function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
  end
end

function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
-  acc = zeros(p.x)
+  acc = zero(p.x)
-  Δacc = zeros(p.x)
+  Δacc = zero(p.x)
  function ()
    @. acc = ρ * acc + (1 - ρ) * p.Δ^2
    @. p.Δ *= √(Δacc + ϵ) / √(acc + ϵ)
@@ -58,8 +58,8 @@ function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
  end
end

function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zeros(p.x)
+  mt = zero(p.x)
-  vt = zeros(p.x)
+  vt = zero(p.x)
  β1p, β2p = β1, β2
  function ()
    @. mt = β1 * mt + (1 - β1) * p.Δ
@@ -71,8 +71,8 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ
  end
end

function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zeros(p.x)
+  mt = zero(p.x)
-  ut = zeros(p.x)
+  ut = zero(p.x)
  β1p = β1
  function ()
    @. mt = β1 * mt + (1 - β1) * p.Δ
@@ -83,9 +83,9 @@ function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999,
  end
end

function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zeros(p.x)
+  mt = zero(p.x)
-  vt = zeros(p.x) .+ ϵ
+  vt = zero(p.x) .+ ϵ
-  v̂t = zeros(p.x) .+ ϵ
+  v̂t = zero(p.x) .+ ϵ
  function ()
    @. mt = β1 * mt + (1 - β1) * p.Δ
    @. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
@@ -95,8 +95,8 @@ function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999,
  end
end

function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
-  mt = zeros(p.x)
+  mt = zero(p.x)
-  vt = zeros(p.x)
+  vt = zero(p.x)
  β1p, β2p = β1, β2
  function ()
    @. mt = β1 * mt + (1 - β1) * p.Δ
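For context, the closures above all follow the same pattern: they keep a state buffer per parameter and rewrite `p.Δ` in place to the step that the surrounding optimiser then applies to `p.x`. A minimal sketch of the momentum rule on plain arrays (illustrative only; `x` and `Δ` stand in for `p.x` and `p.Δ`, and the final `x .-= Δ` update step is assumed to happen elsewhere):

```julia
ρ, η = 0.9, 0.01
x = randn(3)          # hypothetical parameter, stands in for p.x
Δ = randn(3)          # hypothetical gradient, stands in for p.Δ
v = zero(x)           # velocity buffer, initialised as in `momentum` above
@. v = ρ * v - η * Δ  # accumulate velocity
@. Δ = -v             # Δ now holds the step; a later x .-= Δ applies it
```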
@@ -1,5 +1,6 @@
using Juno
using Flux.Tracker: back!
+import Base.depwarn

runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs)
@@ -14,6 +15,25 @@ macro interrupts(ex)
  end)
end

+struct StopException <: Exception end
+"""
+    stop()
+
+Call `Flux.stop()` in a callback to indicate when a callback condition is met.
+This would trigger the train loop to stop and exit.
+
+```julia
+# Example callback:
+
+cb = function ()
+  accuracy() > 0.9 && Flux.stop()
+end
+```
+"""
+function stop()
+  throw(StopException())
+end
+
"""
    train!(loss, data, opt)

@@ -36,10 +56,21 @@ function train!(loss, data, opt; cb = () -> ())
  cb = runall(cb)
  opt = runall(opt)
  @progress for d in data
-    l = loss(d...)
-    @interrupts back!(l)
-    opt()
-    cb() == :stop && break
+    try
+      l = loss(d...)
+      @interrupts back!(l)
+      opt()
+      if cb() == :stop
+        depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
+        break
+      end
+    catch ex
+      if ex isa StopException
+        break
+      else
+        rethrow(ex)
+      end
+    end
  end
end

@@ -59,7 +90,7 @@ hello
"""
macro epochs(n, ex)
  :(@progress for i = 1:$(esc(n))
-    info("Epoch $i")
+    @info "Epoch $i"
    $(esc(ex))
  end)
end
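A hedged usage sketch of the new stopping mechanism together with `train!` and `throttle` (the model `m`, the data iterator `data`, the optimiser `opt`, and the hold-out pair `test_x`/`test_y` are assumed to exist; they are not defined in this patch):

```julia
using Flux

loss(x, y) = Flux.mse(m(x), y)                 # assumes a model `m`
evalcb = Flux.throttle(() -> loss(test_x, test_y) < 0.05 && Flux.stop(), 10)
Flux.train!(loss, data, opt, cb = evalcb)      # exits the loop via StopException
```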
@ -12,7 +12,7 @@ tracker(x) = nothing
|
|||||||
istracked(x) = tracker(x) ≠ nothing
|
istracked(x) = tracker(x) ≠ nothing
|
||||||
isleaf(x) = !istracked(x) || isleaf(tracker(x))
|
isleaf(x) = !istracked(x) || isleaf(tracker(x))
|
||||||
grad(x) = grad(tracker(x))
|
grad(x) = grad(tracker(x))
|
||||||
grad(::Void) = nothing
|
grad(::Nothing) = nothing
|
||||||
data(x) = x
|
data(x) = x
|
||||||
|
|
||||||
struct Call{F,As<:Tuple}
|
struct Call{F,As<:Tuple}
|
||||||
@ -20,7 +20,7 @@ struct Call{F,As<:Tuple}
|
|||||||
args::As
|
args::As
|
||||||
end
|
end
|
||||||
|
|
||||||
Call(f, args) = Call{typeof(f),typeof(args)}(f, args)
|
Call(f::F, args::T) where {F,T} = Call{F,T}(f, args)
|
||||||
Call() = Call(nothing, ())
|
Call() = Call(nothing, ())
|
||||||
|
|
||||||
# When deserialising, the object_id changes
|
# When deserialising, the object_id changes
|
||||||
@ -35,7 +35,7 @@ mutable struct Tracked{T}
|
|||||||
grad::T
|
grad::T
|
||||||
Tracked{T}(f::Call) where T = new(0, f, false)
|
Tracked{T}(f::Call) where T = new(0, f, false)
|
||||||
Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad)
|
Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad)
|
||||||
Tracked{T}(f::Call{Void}, grad::T) where T = new(0, f, true, grad)
|
Tracked{T}(f::Call{Nothing}, grad::T) where T = new(0, f, true, grad)
|
||||||
end
|
end
|
||||||
|
|
||||||
istracked(x::Tracked) = true
|
istracked(x::Tracked) = true
|
||||||
@ -46,7 +46,7 @@ track(f::Call, x) = Tracked{typeof(x)}(f)
|
|||||||
|
|
||||||
function _forward end
|
function _forward end
|
||||||
|
|
||||||
function track(f, xs...; kw...)
|
function track(f::F, xs...; kw...) where F
|
||||||
y, back = _forward(f, xs...; kw...)
|
y, back = _forward(f, xs...; kw...)
|
||||||
track(Call(back, tracker.(xs)), y)
|
track(Call(back, tracker.(xs)), y)
|
||||||
end
|
end
|
||||||
@ -77,10 +77,9 @@ include("numeric.jl")
|
|||||||
|
|
||||||
Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
|
Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
|
||||||
`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
|
`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
|
||||||
the sign of the gradient applied to `x`.
|
the sign of the gradient applied to `x`."""
|
||||||
"""
|
|
||||||
hook(f, x) = istracked(x) ? track(hook, f, x) : x
|
hook(f, x) = istracked(x) ? track(hook, f, x) : x
|
||||||
@grad hook(f, x) = x, Δ -> (nothing, f(Δ))
|
@grad hook(f, x) = data(x), Δ -> (nothing, f(Δ))
|
||||||
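An illustrative use of `hook` as described in the docstring above (a sketch, assuming the tracked API as defined in this file):

```julia
using Flux.Tracker: hook, param, back!, grad

x = param([1.0, 2.0])
back!(sum(hook(-, x)))   # forward value is unchanged; the gradient is negated on the way back
grad(x)                  # ≈ [-1.0, -1.0] instead of [1.0, 1.0]
```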
|
|
||||||
"""
|
"""
|
||||||
checkpoint(f, args...)
|
checkpoint(f, args...)
|
||||||
|
@ -1,3 +1,11 @@
|
|||||||
|
import Base: *
|
||||||
|
|
||||||
|
import LinearAlgebra
|
||||||
|
import LinearAlgebra: inv, \, /
|
||||||
|
|
||||||
|
using Statistics
|
||||||
|
using LinearAlgebra: Transpose, Adjoint, diagm, diag
|
||||||
|
|
||||||
struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
|
struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
|
||||||
tracker::Tracked{A}
|
tracker::Tracked{A}
|
||||||
data::A
|
data::A
|
||||||
@ -21,24 +29,20 @@ TrackedArray(c::Call, x::A) where A <: AbstractArray =
|
|||||||
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
||||||
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, Δ), x, Δ)
|
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, Δ), x, Δ)
|
||||||
|
|
||||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zeros(x))
|
TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zero(x))
|
||||||
|
|
||||||
Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}
|
Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}
|
||||||
|
|
||||||
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
|
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
|
||||||
print(io, "TrackedArray{…,$A}")
|
print(io, "TrackedArray{…,$A}")
|
||||||
|
|
||||||
function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
|
function Base.summary(io::IO, x::TrackedArray)
|
||||||
if repr
|
print(io, "Tracked ")
|
||||||
print(io, "param(")
|
summary(io, data(x))
|
||||||
Base.showarray(io, data(X), true)
|
|
||||||
print(io, ")")
|
|
||||||
else
|
|
||||||
header && print(io, "Tracked ")
|
|
||||||
Base.showarray(io, data(X), false, header = header)
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Base.print_array(io::IO, x::TrackedArray) = Base.print_array(io, data(x))
|
||||||
|
|
||||||
Base.setindex!(xs::TrackedArray, v, i...) =
|
Base.setindex!(xs::TrackedArray, v, i...) =
|
||||||
error("Can't differentiate `setindex!`")
|
error("Can't differentiate `setindex!`")
|
||||||
|
|
||||||
@ -46,7 +50,7 @@ back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back
|
|||||||
|
|
||||||
# Fallthrough methods
|
# Fallthrough methods
|
||||||
|
|
||||||
for f in :[Base.size, Base.ndims].args
|
for f in :[Base.size, Base.ndims, Base.collect].args
|
||||||
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
|
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -58,9 +62,11 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
|
|||||||
|
|
||||||
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
|
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
|
||||||
|
|
||||||
Base.:(==)(x::TrackedArray, y) = data(x) == y
|
for op in [:(==), :≈]
|
||||||
Base.:(==)(y, x::TrackedArray) = y == data(x)
|
@eval Base.$op(x::TrackedArray, y::AbstractArray) = Base.$op(data(x), y)
|
||||||
Base.:(==)(x::TrackedArray, y::TrackedArray) = data(x) == data(y)
|
@eval Base.$op(x::AbstractArray, y::TrackedArray) = Base.$op(x, data(y))
|
||||||
|
@eval Base.$op(x::TrackedArray, y::TrackedArray) = Base.$op(data(x), data(y))
|
||||||
|
end
|
||||||
|
|
||||||
# Array Stdlib
|
# Array Stdlib
|
||||||
|
|
||||||
@ -79,37 +85,20 @@ Base.:-(xs::TrackedArray) = track(-, xs)
|
|||||||
@grad -(xs) = -data(xs), Δ -> (-Δ,)
|
@grad -(xs) = -data(xs), Δ -> (-Δ,)
|
||||||
|
|
||||||
Base.transpose(xs::TrackedArray) = track(transpose, xs)
|
Base.transpose(xs::TrackedArray) = track(transpose, xs)
|
||||||
Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs)
|
Base.adjoint(xs::TrackedArray) = track(adjoint, xs)
|
||||||
|
|
||||||
@grad transpose(xs) = data(xs).', Δ -> (reshape(Δ.', size(xs)),)
|
@grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),)
|
||||||
@grad ctranspose(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)
|
@grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)
|
||||||
|
|
||||||
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
|
Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
|
||||||
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
|
|
||||||
|
|
||||||
@grad function repmat(xs, m, n = 1)
|
@grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs)))
|
||||||
repmat(data(xs), m, n), function (Δ)
|
|
||||||
Δ′ = similar(xs)
|
|
||||||
S = size(xs)
|
|
||||||
for (i,v) in enumerate(data(Δ))
|
|
||||||
d1 = divrem(i-1, S[1]*m)
|
|
||||||
x = d1[2] % S[1]+1
|
|
||||||
y = d1[1] % S[2]+1
|
|
||||||
Δ′[x, y] += v
|
|
||||||
end
|
|
||||||
return (nobacksies(:repmat, Δ′), nothing, nothing)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
Base.repeat(A::TrackedArray; kw...) = track(repeat, A; kw...)
|
|
||||||
|
|
||||||
@grad function repeat(xs; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A)))
|
|
||||||
repeat(data(xs), inner = inner, outer = outer), function (Δ)
|
repeat(data(xs), inner = inner, outer = outer), function (Δ)
|
||||||
Δ′ = zero(xs)
|
Δ′ = zero(xs)
|
||||||
S = size(xs)
|
S = size(xs)
|
||||||
|
|
||||||
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
|
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
|
||||||
for (dest_idx, val) in enumerate(IndexCartesian(), data(Δ))
|
for (dest_idx, val) in pairs(IndexCartesian(), data(Δ))
|
||||||
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
|
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
|
||||||
# wrap around based on original size S.
|
# wrap around based on original size S.
|
||||||
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
|
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
|
||||||
@ -119,8 +108,8 @@ Base.repeat(A::TrackedArray; kw...) = track(repeat, A; kw...)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
for f in [:vcat, :hcat]
|
for f in [:vcat, :hcat]
|
||||||
|
UArray = :(Union{TrackedArray,Vector,Matrix,Adjoint,Transpose})
|
||||||
@eval begin
|
@eval begin
|
||||||
# This section is a bit of a hack since julia doesn't have a standardised
|
# This section is a bit of a hack since julia doesn't have a standardised
|
||||||
# promotion mechanism for concatenation yet
|
# promotion mechanism for concatenation yet
|
||||||
@ -129,18 +118,18 @@ for f in [:vcat, :hcat]
|
|||||||
# It should support tracked concatenation with rank ∈ (1,2) with a
|
# It should support tracked concatenation with rank ∈ (1,2) with a
|
||||||
# TrackedArray anywhere among the arguments This works as long as base has
|
# TrackedArray anywhere among the arguments This works as long as base has
|
||||||
# other functions that captures `(::Union{Vector,RowVector,Matrix}...)`.
|
# other functions that captures `(::Union{Vector,RowVector,Matrix}...)`.
|
||||||
Base.$f(a::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a...)
|
Base.$f(a::$UArray...) = track($f, a...)
|
||||||
|
|
||||||
# It should support tracked concatenation with rank>2 if the TrackedArray is
|
# It should support tracked concatenation with rank>2 if the TrackedArray is
|
||||||
# first
|
# first
|
||||||
Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...)
|
Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...)
|
||||||
Base.$f(a::TrackedArray, b::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a, b...) # resolves ambiguity introduced by previous row
|
Base.$f(a::TrackedArray, b::$UArray...) = track($f, a, b...) # resolves ambiguity introduced by previous row
|
||||||
|
|
||||||
# It should support tracked concatenation with rank>2 if the TrackedArray is
|
# It should support tracked concatenation with rank>2 if the TrackedArray is
|
||||||
# second
|
# second
|
||||||
Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...)
|
Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...)
|
||||||
Base.$f(a::Union{Vector,RowVector,Matrix}, b::TrackedArray,
|
Base.$f(a::Union{Vector,Matrix,Adjoint,Transpose}, b::TrackedArray,
|
||||||
c::Union{TrackedArray,Vector,RowVector,Matrix}...) =
|
c::$UArray...) =
|
||||||
track($f, a, b, c...) # resolves ambiguity introduced by previous row
|
track($f, a, b, c...) # resolves ambiguity introduced by previous row
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@ -175,21 +164,23 @@ end
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...)
|
Base.cat(a::TrackedArray; dims) = track(cat, a, dims = dims)
|
||||||
Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...)
|
Base.cat(a::TrackedArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
|
||||||
|
Base.cat(a::TrackedArray, b::AbstractArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
|
||||||
|
Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
|
||||||
|
|
||||||
@grad function cat(dims, Xs...)
|
@grad function cat(Xs...; dims)
|
||||||
cat(dims, data.(Xs)...), function (Δ)
|
cat(data.(Xs)..., dims = dims), function (Δ)
|
||||||
start = ntuple(i -> 0, Val{ndims(Δ)})
|
start = ntuple(i -> 0, Val(ndims(Δ)))
|
||||||
Δs = [begin
|
Δs = [begin
|
||||||
dim_xs = 1:ndims(xs)
|
dim_xs = 1:ndims(xs)
|
||||||
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)})
|
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ)))
|
||||||
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)})
|
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ)))
|
||||||
d = reshape(Δ[xs_in_Δ...],size(xs))
|
d = reshape(Δ[xs_in_Δ...],size(xs))
|
||||||
start = start .+ till_xs
|
start = start .+ till_xs
|
||||||
d
|
d
|
||||||
end for xs in Xs]
|
end for xs in Xs]
|
||||||
return (nothing, Δs...,)
|
return (Δs...,)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -216,100 +207,123 @@ Base.kron(a::TrackedMatrix, b::TrackedMatrix) = _kron(a, b)
|
|||||||
Base.kron(a::TrackedMatrix, b::AbstractMatrix) = _kron(a, b)
|
Base.kron(a::TrackedMatrix, b::AbstractMatrix) = _kron(a, b)
|
||||||
Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b)
|
Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
inv(A::TrackedArray) = Tracker.track(inv, A)
|
||||||
|
@grad function inv(A)
|
||||||
|
return inv(Tracker.data(A)), function (Δ)
|
||||||
|
Ainv = inv(A)
|
||||||
|
∇A = - Ainv' * Δ * Ainv'
|
||||||
|
return (∇A, )
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
# (/) rdivide
|
||||||
|
A::TrackedArray / B::TrackedArray = Tracker.track(/, A, B)
|
||||||
|
A::AbstractVecOrMat / B::TrackedArray = Tracker.track(/, A, B)
|
||||||
|
A::TrackedArray / B::AbstractVecOrMat = Tracker.track(/, A, B)
|
||||||
|
@grad function (A / B)
|
||||||
|
return Tracker.data(A) / Tracker.data(B), function (Δ)
|
||||||
|
Binv = inv(B)
|
||||||
|
∇B = - Binv' * A' * Δ * Binv'
|
||||||
|
return (Δ * Binv', ∇B)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
# (\) ldivide (left vec divide needs more work to resolve dispatch ambiguity)
|
||||||
|
A::TrackedArray \ B::TrackedArray = Tracker.track(\, A, B)
|
||||||
|
A::AbstractArray \ B::TrackedArray = Tracker.track(\, A, B)
|
||||||
|
A::TrackedArray \ B::AbstractVecOrMat = Tracker.track(\, A, B)
|
||||||
|
@grad function (A \ B)
|
||||||
|
return Tracker.data(A) \ Tracker.data(B), function (Δ)
|
||||||
|
Ainv = inv(A)
|
||||||
|
∇A = - Ainv' * Δ * B' * Ainv'
|
||||||
|
return (∇A, Ainv' * Δ)
|
||||||
|
end
|
||||||
|
end
|
||||||
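A quick numerical sanity check of the `inv`, `/` and `\` adjoints added above (illustrative only; `gradient` is Tracker's own helper elsewhere in this patch, and the identity shift simply keeps `A` well-conditioned):

```julia
using LinearAlgebra
using Flux.Tracker: gradient

A = rand(4, 4) + 4 * Matrix(1.0I, 4, 4)     # comfortably non-singular
B = rand(4, 4)
gA, gB = gradient((A, B) -> sum(A \ B), A, B)
size(gA) == size(A) && size(gB) == size(B)  # true; values can be cross-checked with ngradient
```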
|
|
||||||
|
|
||||||
# Reductions
|
# Reductions
|
||||||
|
|
||||||
Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim)
|
Base.sum(xs::TrackedArray; dims = :) = track(sum, xs, dims = dims)
|
||||||
Base.sum(xs::TrackedArray) = track(sum, xs)
|
|
||||||
Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))
|
Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))
|
||||||
|
|
||||||
@grad sum(xs, dim...) = sum(data(xs), dim...),
|
@grad sum(xs; dims = :) = sum(data(xs), dims = dims),
|
||||||
Δ -> (zero(xs) .+ Δ, map(_->nothing,dim)...)
|
Δ -> (zero(xs) .+ Δ, )
|
||||||
|
|
||||||
Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
|
Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
|
||||||
Base.prod(xs::TrackedArray) = track(prod, xs)
|
Base.prod(xs::TrackedArray) = track(prod, xs)
|
||||||
Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
|
Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
|
||||||
|
|
||||||
@grad prod(xs) = prod(data(xs)), Δ -> (prod(xs) ./ xs .* Δ,)
|
@grad prod(xs) = prod(data(xs)), Δ -> (prod(xs) ./ xs .* Δ,)
|
||||||
@grad prod(xs, dim) = prod(data(xs), dim),
|
@grad prod(xs, dim) = prod(data(xs), dims = dim),
|
||||||
Δ -> (nobacksies(:sum,
|
Δ -> (nobacksies(:sum,
|
||||||
reshape(.*(circshift.([reshape(data(xs), length(xs))], 1:length(xs)-1)...), size(xs)) .* Δ),
|
reshape(.*(circshift.([reshape(data(xs), length(xs))], 1:length(xs)-1)...), size(xs)) .* Δ),
|
||||||
nothing)
|
nothing)
|
||||||
|
|
||||||
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
|
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
|
||||||
|
|
||||||
Base.mean(xs::TrackedArray) = track(mean, xs)
|
Statistics.mean(xs::TrackedArray; dims = :) = track(mean, xs, dims = dims)
|
||||||
Base.mean(xs::TrackedArray, region) = track(mean, xs, region)
|
|
||||||
|
|
||||||
Base.maximum(xs::TrackedArray) = track(maximum, xs)
|
Base.maximum(xs::TrackedArray; dims = :) = track(maximum, xs, dims = dims)
|
||||||
Base.maximum(xs::TrackedArray, region) = track(maximum, xs, region)
|
Base.minimum(xs::TrackedArray; dims = :) = track(minimum, xs, dims = dims)
|
||||||
Base.minimum(xs::TrackedArray) = track(minimum, xs)
|
|
||||||
Base.minimum(xs::TrackedArray, region) = track(minimum, xs, region)
|
|
||||||
|
|
||||||
LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
|
import LinearAlgebra: dot
|
||||||
LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
|
|
||||||
LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
|
dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
|
||||||
|
dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
|
||||||
|
dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
|
||||||
|
|
||||||
@grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs)
|
@grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs)
|
||||||
|
|
||||||
# Hacks to get std working
|
# Hacks to get std working
|
||||||
Base.std(x::TrackedArray; mean = Base.mean(x)) =
|
Statistics.std(x::TrackedArray; dims = :, mean = Statistics.mean(x, dims = dims)) = _std(x,mean,dims)
|
||||||
sqrt.(sum((x .- mean).^2) ./ (length(x)-1))
|
_std(x::TrackedArray, mean, dims) = sqrt.(sum((x .- mean).^2, dims = dims) ./ (mapreduce(i -> size(x,i),*, dims) - 1))
|
||||||
Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) =
|
_std(x::TrackedArray, mean, ::Colon) = sqrt.(sum((x .- mean).^2) ./ (length(x) - 1))
|
||||||
sqrt.(sum((x .- mean).^2, dim) ./ (size(x, dim)-1))
|
|
||||||
|
|
||||||
Base.vecnorm(x::TrackedArray, p::Real = 2) =
|
LinearAlgebra.norm(x::TrackedArray, p::Real = 2) =
|
||||||
sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0
|
sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0
|
||||||
|
|
||||||
@grad mean(xs) = mean(data(xs)), Δ -> (Δ / length(xs),)
|
@grad mean(xs; dims = :) = mean(data(xs), dims=dims), Δ -> (_backmean(xs,Δ,dims),)
|
||||||
@grad mean(xs, region) = mean(data(xs), region), Δ -> (zero(xs) .+ Δ ./ prod(size(xs, region...)),nothing)
|
_backmean(xs, Δ, ::Colon) = zero(xs) .+ Δ ./ length(xs)
|
||||||
|
_backmean(xs, Δ, dims) = zero(xs) .+ Δ ./ mapreduce(i -> size(data(xs),i),*,dims)
|
||||||
|
|
||||||
@grad function maximum(xs, r...)
|
@grad function maximum(xs; dims = dims)
|
||||||
maximum(data(xs), r...), function (Δ)
|
maximum(data(xs), dims = dims), function (Δ)
|
||||||
Δ′ = zero(xs)
|
Δ′ = zero(xs)
|
||||||
_, i = findmax(data(xs), r...)
|
_, i = findmax(data(xs), dims = dims)
|
||||||
Δ′[i] = data(Δ)
|
Δ′[i] = data(Δ)
|
||||||
return (nobacksies(:maximum, Δ′),map(_->nothing,r)...)
|
return (nobacksies(:maximum, Δ′),)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@grad function minimum(xs, r...)
|
|
||||||
minimum(data(xs), r...), function (Δ)
|
@grad function minimum(xs; dims = dims)
|
||||||
|
minimum(data(xs), dims = dims), function (Δ)
|
||||||
Δ′ = zero(xs)
|
Δ′ = zero(xs)
|
||||||
_, i = findmin(data(xs), r...)
|
_, i = findmin(data(xs), dims = dims)
|
||||||
Δ′[i] = data(Δ)
|
Δ′[i] = data(Δ)
|
||||||
return (nobacksies(:minimum, Δ′),map(_->nothing,r)...)
|
return (nobacksies(:minimum, Δ′),)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
# BLAS
|
# BLAS
|
||||||
|
|
||||||
Base.diagm(x::TrackedVector) = track(diagm, x)
|
LinearAlgebra.diagm(x::TrackedVector) = track(diagm, x)
|
||||||
@grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),)
|
@grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),)
|
||||||
|
|
||||||
for f in :[*, Ac_mul_B, A_mul_Bc, A_mul_Bt, At_mul_B].args
|
x::TrackedMatrix * y::AbstractMatrix = track(*, x, y)
|
||||||
@eval begin
|
x::AbstractMatrix * y::TrackedMatrix = track(*, x, y)
|
||||||
import Base.$f
|
x::TrackedMatrix * y::TrackedMatrix = track(*, x, y)
|
||||||
$f(a::TrackedMatrix, b::TrackedMatrix) = track($f, a, b)
|
|
||||||
$f(a::TrackedMatrix, b::AbstractMatrix) = track($f, a, b)
|
|
||||||
$f(a::AbstractMatrix, b::TrackedMatrix) = track($f, a, b)
|
|
||||||
|
|
||||||
$f(a::TrackedMatrix, b::TrackedVector) = track($f, a, b)
|
x::TrackedMatrix * y::AbstractVector = track(*, x, y)
|
||||||
$f(a::TrackedMatrix, b::AbstractVector) = track($f, a, b)
|
x::AbstractMatrix * y::TrackedVector = track(*, x, y)
|
||||||
$f(a::AbstractMatrix, b::TrackedVector) = track($f, a, b)
|
x::TrackedMatrix * y::TrackedVector = track(*, x, y)
|
||||||
|
|
||||||
$f(a::TrackedVector, b::TrackedVector) = track($f, a, b)
|
x::TrackedVector * y::AbstractVector = track(*, x, y)
|
||||||
$f(a::TrackedVector, b::AbstractVector) = track($f, a, b)
|
x::AbstractVector * y::TrackedVector = track(*, x, y)
|
||||||
$f(a::AbstractVector, b::TrackedVector) = track($f, a, b)
|
x::TrackedVector * y::TrackedVector = track(*, x, y)
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
@grad a::AbstractMatrix * b::AbstractVecOrMat =
|
@grad a::AbstractMatrix * b::AbstractVecOrMat =
|
||||||
data(a)*data(b), Δ -> (A_mul_Bt(Δ, b), At_mul_B(a, Δ))
|
data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ)
|
||||||
|
|
||||||
@grad Ac_mul_B(a, b) = Ac_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ)
|
|
||||||
@grad A_mul_Bc(a, b) = A_mul_Bc(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)')
|
|
||||||
|
|
||||||
@grad At_mul_B(a, b) = At_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ)
|
|
||||||
@grad A_mul_Bt(a, b) = A_mul_Bt(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)')
|
|
||||||
|
|
||||||
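The rewritten matmul adjoint above can be checked directly: with `Δ = ones(size(A*B))` the two gradients are `Δ * Bᵀ` and `Aᵀ * Δ` (a sketch, assuming Tracker's `gradient` and `data`):

```julia
using Flux.Tracker: gradient, data

A, B = rand(2, 3), rand(3, 4)
gA, gB = gradient((A, B) -> sum(A * B), A, B)
data(gA) ≈ ones(2, 4) * transpose(B)   # true
data(gB) ≈ transpose(A) * ones(2, 4)   # true
```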
# NNlib
|
# NNlib
|
||||||
|
|
||||||
@ -352,45 +366,85 @@ end
|
|||||||
|
|
||||||
using ForwardDiff: Dual, partials, value
|
using ForwardDiff: Dual, partials, value
|
||||||
|
|
||||||
dualify(xs, n) = xs
|
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
|
||||||
dualify(xs::AbstractArray, ps) = map(x -> Dual(x, ps), xs)
|
|
||||||
dualify(xs::Real, ps) = Dual(xs, ps)
|
|
||||||
|
|
||||||
unbroadcast(x::Tuple, Δ) =
|
unbroadcast(x::AbstractArray, Δ) =
|
||||||
x == size(Δ) ? Δ :
|
size(x) == size(Δ) ? Δ :
|
||||||
reshape(sum(Δ, filter(n -> n > length(x) || x[n] == 1, 1:ndims(Δ))), x)
|
length(x) == length(Δ) ? trim(x, Δ) :
|
||||||
|
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
|
||||||
|
|
||||||
unbroadcast(x::Tuple{}, Δ) = sum(Δ)
|
unbroadcast(x::Number, Δ) = sum(Δ)
|
||||||
|
unbroadcast(x::Base.RefValue{<:Function}, _) = nothing
|
||||||
|
unbroadcast(x::Base.RefValue{<:Val}, _) = nothing
|
||||||
|
|
||||||
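`trim` and `unbroadcast` are internal helpers: `unbroadcast` folds a gradient back onto the shape of a broadcast argument by summing over the dimensions that were expanded. A small sketch of the behaviour (assuming the definitions above are reachable as `Flux.Tracker.unbroadcast`):

```julia
import Flux.Tracker: unbroadcast

x = rand(3, 1)        # a column that was broadcast against a 3×4 result
Δ = ones(3, 4)        # incoming gradient for that result
unbroadcast(x, Δ)     # 3×1 array of 4.0: Δ summed over the expanded dimension
unbroadcast(2.0, Δ)   # 12.0: a scalar argument collects the whole sum
```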
function getpartial(Δ, x, i)
|
dual(x, p) = x
|
||||||
@inbounds p = getindex(partials(x), i)
|
dual(x::Real, p) = Dual(x, p)
|
||||||
return Δ * p
|
|
||||||
|
function partial(f::F, Δ, i, args::Vararg{Any,N}) where {F,N}
|
||||||
|
dargs = ntuple(j -> dual(args[j], i==j), Val(N))
|
||||||
|
return Δ * f(dargs...).partials[1]
|
||||||
end
|
end
|
||||||
|
|
||||||
function ∇broadcast(f, args::Vararg{Any,N}) where N
|
@inline function ∇broadcast(f::F, args::Vararg{Any,N}) where {F,N}
|
||||||
sizes = size.(args)
|
y = broadcast(f, data.(args)...)
|
||||||
dargs = map((x,i) -> dualify(data(x), ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
|
eltype(y) <: Real || return y
|
||||||
out = broadcast(f, dargs...)
|
eltype(y) == Bool && return y
|
||||||
eltype(out) <: Dual || return out
|
function back(Δ)
|
||||||
y = value.(out)
|
Δargs = ntuple(i -> partial.(f, Δ, i, args...), Val(N))
|
||||||
back = function (Δ_)
|
dxs = map(unbroadcast, args, Δargs)
|
||||||
Δ = data(Δ_)
|
return dxs
|
||||||
Δargs = ntuple(i -> getpartial.(Δ, out, i), Val{N})
|
|
||||||
dxs = map((x, Δ) -> unbroadcast(x, Δ), sizes, Δargs)
|
|
||||||
nobacksies(:broadcast, dxs)
|
|
||||||
end
|
end
|
||||||
# So we can return non-tracked arrays
|
# So we can return non-tracked arrays
|
||||||
track(Call(back, tracker.(args)), y)
|
track(Call(back, tracker.(args)), y)
|
||||||
end
|
end
|
||||||
|
|
||||||
Base.Broadcast._containertype(::Type{<:TrackedReal}) = TrackedArray
|
using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted
|
||||||
Base.Broadcast._containertype(::Type{<:TrackedArray}) = TrackedArray
|
|
||||||
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{TrackedArray}) = TrackedArray
|
|
||||||
Base.Broadcast.promote_containertype(::Type{Array}, ::Type{TrackedArray}) = TrackedArray
|
|
||||||
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{Array}) = TrackedArray
|
|
||||||
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ct) = TrackedArray
|
|
||||||
Base.Broadcast.promote_containertype(ct, ::Type{TrackedArray}) = TrackedArray
|
|
||||||
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A::Ref) = ()
|
|
||||||
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A) = indices(A)
|
|
||||||
|
|
||||||
Base.Broadcast.broadcast_c(f, ::Type{TrackedArray}, A, Bs...) = ∇broadcast(f, A, Bs...)
|
struct TrackedStyle <: BroadcastStyle end
|
||||||
|
|
||||||
|
Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray,TrackedReal}}) = TrackedStyle()
|
||||||
|
Broadcast.BroadcastStyle(::TrackedStyle, ::BroadcastStyle) = TrackedStyle()
|
||||||
|
|
||||||
|
# We have to re-build the original broadcast struct to get the appropriate array
|
||||||
|
# style. We need this primarily to support CuArrays' broadcasting fixes.
|
||||||
|
broadcast_rebuild(xs) = data(xs)
|
||||||
|
|
||||||
|
broadcast_rebuild(bc::Broadcasted) =
|
||||||
|
broadcasted(bc.f, broadcast_rebuild.(bc.args)...)
|
||||||
|
|
||||||
|
preprocess(x) = x
|
||||||
|
|
||||||
|
function Base.Broadcast.materialize(bc::Broadcasted{TrackedStyle})
|
||||||
|
bc1 = Broadcast.flatten(bc)
|
||||||
|
bc2 = Broadcast.flatten(broadcast_rebuild(bc))
|
||||||
|
∇broadcast(bc2.f, bc1.args...)
|
||||||
|
end
|
||||||
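With `TrackedStyle` and the `materialize` overload above, broadcasting over tracked arrays behaves like ordinary broadcasting while still recording a gradient (illustrative):

```julia
using Flux.Tracker: param, back!, grad

x = param([1.0, 2.0, 3.0])
back!(sum(exp.(2 .* x)))                     # dispatches through ∇broadcast
grad(x) ≈ 2 .* exp.(2 .* [1.0, 2.0, 3.0])    # true
```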
|
|
||||||
|
using Requires
|
||||||
|
|
||||||
|
# https://github.com/FluxML/Flux.jl/issues/353
|
||||||
|
@init Requires.isprecompiling() || @eval Base.Broadcast begin
|
||||||
|
function flatten(bc::Broadcasted{Style}) where {Style}
|
||||||
|
isflat(bc) && return bc
|
||||||
|
args = cat_nested(bc)
|
||||||
|
let makeargs = make_makeargs(bc), f = bc.f
|
||||||
|
newf = @inline function(args::Vararg{Any,N}) where N
|
||||||
|
f(makeargs(args...)...)
|
||||||
|
end
|
||||||
|
return Broadcasted{Style}(newf, args, bc.axes)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
@inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}})
|
||||||
|
bc = t[1]
|
||||||
|
let makeargs = make_makeargs(makeargs, tail(t)), f = bc.f
|
||||||
|
let makeargs = make_makeargs(makeargs, bc.args)
|
||||||
|
headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
|
||||||
|
return @inline function(args::Vararg{Any,N}) where N
|
||||||
|
args1 = makeargs(args...)
|
||||||
|
a, b = headargs(args1...), tailargs(args1...)
|
||||||
|
(f(a...), b...)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
@ -26,7 +26,7 @@ function back_(c::Call, Δ)
|
|||||||
foreach(back, c.args, data.(Δs))
|
foreach(back, c.args, data.(Δs))
|
||||||
end
|
end
|
||||||
|
|
||||||
back_(::Call{Void}, Δ) = nothing
|
back_(::Call{Nothing}, Δ) = nothing
|
||||||
|
|
||||||
accum!(x, Δ) = x .+ Δ
|
accum!(x, Δ) = x .+ Δ
|
||||||
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
||||||
@ -47,7 +47,7 @@ function back(x::Tracked, Δ)
|
|||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
back(::Void, _) = return
|
back(::Nothing, _) = return
|
||||||
|
|
||||||
# Interface methods
|
# Interface methods
|
||||||
|
|
||||||
@ -70,7 +70,7 @@ struct Params
|
|||||||
Params(xs) = new(IdSet(xs))
|
Params(xs) = new(IdSet(xs))
|
||||||
end
|
end
|
||||||
|
|
||||||
@forward Params.params Base.start, Base.next, Base.done
|
@forward Params.params Base.iterate, Base.length
|
||||||
|
|
||||||
function Base.show(io::IO, ps::Params)
|
function Base.show(io::IO, ps::Params)
|
||||||
print(io, "Params([")
|
print(io, "Params([")
|
||||||
@ -79,14 +79,16 @@ function Base.show(io::IO, ps::Params)
|
|||||||
end
|
end
|
||||||
|
|
||||||
struct Grads
|
struct Grads
|
||||||
grads::ObjectIdDict
|
grads::IdDict{Any,Any}
|
||||||
end
|
end
|
||||||
|
|
||||||
Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
|
Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
|
||||||
|
|
||||||
Grads() = Grads(ObjectIdDict())
|
Grads() = Grads(IdDict())
|
||||||
|
|
||||||
Grads(ps::Params) = Grads(ObjectIdDict(tracker(p) => init_grad(data(p)) for p in ps))
|
@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate
|
||||||
|
|
||||||
|
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
|
||||||
|
|
||||||
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
|
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
|
||||||
function Base.getindex(g::Grads, x)
|
function Base.getindex(g::Grads, x)
|
||||||
@ -94,9 +96,8 @@ function Base.getindex(g::Grads, x)
|
|||||||
g[tracker(x)]
|
g[tracker(x)]
|
||||||
end
|
end
|
||||||
|
|
||||||
@forward Grads.grads Base.setindex!, Base.haskey
|
|
||||||
|
|
||||||
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] + Δ : Δ
|
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
|
||||||
|
|
||||||
function back_(g::Grads, c::Call, Δ)
|
function back_(g::Grads, c::Call, Δ)
|
||||||
Δs = c.func(Δ)
|
Δs = c.func(Δ)
|
||||||
@ -105,7 +106,7 @@ function back_(g::Grads, c::Call, Δ)
|
|||||||
foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
|
foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
|
||||||
end
|
end
|
||||||
|
|
||||||
back_(g::Grads, ::Call{Void}, Δ) = nothing
|
back_(g::Grads, ::Call{Nothing}, Δ) = nothing
|
||||||
|
|
||||||
function back(g::Grads, x::Tracked, Δ)
|
function back(g::Grads, x::Tracked, Δ)
|
||||||
x.isleaf && (accum!(g, x, Δ); return)
|
x.isleaf && (accum!(g, x, Δ); return)
|
||||||
@ -119,7 +120,7 @@ function back(g::Grads, x::Tracked, Δ)
|
|||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
back(::Grads, ::Void, _) = return
|
back(::Grads, ::Nothing, _) = return
|
||||||
|
|
||||||
function forward(f, ps::Params)
|
function forward(f, ps::Params)
|
||||||
y = f()
|
y = f()
|
||||||
@ -136,7 +137,7 @@ end
|
|||||||
function forward(f, args...)
|
function forward(f, args...)
|
||||||
args = param.(args)
|
args = param.(args)
|
||||||
y, back = forward(() -> f(args...), Params(args))
|
y, back = forward(() -> f(args...), Params(args))
|
||||||
y, Δ -> getindex.(back(Δ), args)
|
y, Δ -> getindex.(Ref(back(Δ)), args)
|
||||||
end
|
end
|
||||||
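A sketch of the `Params`/`Grads` interface that the code above implements (hedged; the names are taken from this file):

```julia
using Flux.Tracker: forward, Params, param

W = param(rand(2, 3))
x = rand(3)
y, back = forward(() -> sum(W * x), Params([W]))
gs = back(1)               # a Grads mapping each tracked parameter to its gradient
size(gs[W]) == size(W)     # true
```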
|
|
||||||
function losscheck(x)
|
function losscheck(x)
|
||||||
@ -152,3 +153,13 @@ function gradient(f, args...)
|
|||||||
end
|
end
|
||||||
|
|
||||||
derivative(f, x) = gradient(f, x)[1]
|
derivative(f, x) = gradient(f, x)[1]
|
||||||
|
|
||||||
|
# Non-nesting versions
|
||||||
|
|
||||||
|
function gradient_(f, xs...)
|
||||||
|
xs = param.(xs)
|
||||||
|
l = f(xs...)
|
||||||
|
losscheck(l)
|
||||||
|
back!(l)
|
||||||
|
grad.(xs)
|
||||||
|
end
|
||||||
|
@ -1,17 +1,17 @@
|
|||||||
struct IdSet{T} <: AbstractSet{T}
|
struct IdSet{T} <: AbstractSet{T}
|
||||||
dict::ObjectIdDict
|
dict::IdDict{T,Nothing}
|
||||||
IdSet{T}() where T = new(ObjectIdDict())
|
IdSet{T}() where T = new(IdDict{T,Nothing}())
|
||||||
end
|
end
|
||||||
|
|
||||||
Base.eltype{T}(::IdSet{T}) = T
|
Base.eltype(::IdSet{T}) where T = T
|
||||||
|
|
||||||
IdSet() = IdSet{Any}()
|
IdSet() = IdSet{Any}()
|
||||||
|
|
||||||
Base.push!{T}(s::IdSet{T}, x::T) = (s.dict[x] = nothing; s)
|
Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
|
||||||
Base.delete!{T}(s::IdSet{T}, x::T) = (delete!(s.dict, x); s)
|
Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
|
||||||
Base.in(x, s::IdSet) = haskey(s.dict, x)
|
Base.in(x, s::IdSet) = haskey(s.dict, x)
|
||||||
|
|
||||||
(::Type{IdSet{T}}){T}(xs) = push!(IdSet{T}(), xs...)
|
IdSet{T}(xs) where T = push!(IdSet{T}(), xs...)
|
||||||
|
|
||||||
IdSet(xs) = IdSet{eltype(xs)}(xs)
|
IdSet(xs) = IdSet{eltype(xs)}(xs)
|
||||||
|
|
||||||
@ -20,6 +20,8 @@ Base.similar(s::IdSet, T::Type) = IdSet{T}()
|
|||||||
|
|
||||||
@forward IdSet.dict Base.length
|
@forward IdSet.dict Base.length
|
||||||
|
|
||||||
Base.start(s::IdSet) = start(keys(s.dict))
|
function Base.iterate(v::IdSet, state...)
|
||||||
Base.next(s::IdSet, st) = next(keys(s.dict), st)
|
y = Base.iterate(keys(v.dict), state...)
|
||||||
Base.done(s::IdSet, st) = done(keys(s.dict), st)
|
y === nothing && return nothing
|
||||||
|
return (y[1], y[2])
|
||||||
|
end
|
||||||
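A quick check of the new iteration protocol (`IdSet` is an internal helper; it is assumed here to be reachable from `Flux.Tracker`):

```julia
import Flux.Tracker: IdSet

s = IdSet([1, 2, 3])
length(s) == 3                   # true
sort(collect(s)) == [1, 2, 3]    # true: collect now works via Base.iterate
```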
|
@ -1,5 +1,5 @@
|
|||||||
function ngradient(f, xs::AbstractArray...)
|
function ngradient(f, xs::AbstractArray...)
|
||||||
grads = zeros.(xs)
|
grads = zero.(xs)
|
||||||
for (x, Δ) in zip(xs, grads), i in 1:length(x)
|
for (x, Δ) in zip(xs, grads), i in 1:length(x)
|
||||||
δ = sqrt(eps())
|
δ = sqrt(eps())
|
||||||
tmp = x[i]
|
tmp = x[i]
|
||||||
|
@ -30,8 +30,11 @@ Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x
|
|||||||
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
|
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
|
||||||
error("Not implemented: convert tracked $S to tracked $T")
|
error("Not implemented: convert tracked $S to tracked $T")
|
||||||
|
|
||||||
Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y)
|
for op in [:(==), :≈, :<]
|
||||||
Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y)
|
@eval Base.$op(x::TrackedReal, y::Real) = Base.$op(data(x), y)
|
||||||
|
@eval Base.$op(x::Real, y::TrackedReal) = Base.$op(x, data(y))
|
||||||
|
@eval Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y))
|
||||||
|
end
|
||||||
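With the comparison overloads above, tracked scalars compare like plain numbers (illustrative):

```julia
using Flux.Tracker: param

x = param(2.0)                 # a TrackedReal
(x == 2.0, 1.9 < x, x ≈ 2.0)   # (true, true, true)
```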
|
|
||||||
Base.eps(x::TrackedReal) = eps(data(x))
|
Base.eps(x::TrackedReal) = eps(data(x))
|
||||||
|
|
||||||
@ -115,3 +118,7 @@ end
|
|||||||
function back_(c::Call{typeof(collect)}, Δ)
|
function back_(c::Call{typeof(collect)}, Δ)
|
||||||
foreach(back, c.args[1], data(Δ))
|
foreach(back, c.args[1], data(Δ))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function back_(g::Grads, c::Call{typeof(collect)}, Δ)
|
||||||
|
foreach((x, Δ) -> back(g, x, Δ), c.args[1], Δ)
|
||||||
|
end
|
||||||
|
@ -7,16 +7,27 @@ mapchildren(f, x) = x
|
|||||||
children(x::Tuple) = x
|
children(x::Tuple) = x
|
||||||
mapchildren(f, x::Tuple) = map(f, x)
|
mapchildren(f, x::Tuple) = map(f, x)
|
||||||
|
|
||||||
function treelike(T, fs = fieldnames(T))
|
function treelike(m::Module, T, fs = fieldnames(T))
|
||||||
@eval current_module() begin
|
@eval m begin
|
||||||
Flux.children(x::$T) = ($([:(x.$f) for f in fs]...),)
|
Flux.children(x::$T) = ($([:(x.$f) for f in fs]...),)
|
||||||
Flux.mapchildren(f, x::$T) = $T(f.($children(x))...)
|
Flux.mapchildren(f, x::$T) = $T(f.($children(x))...)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function treelike(T, fs = fieldnames(T))
|
||||||
|
Base.depwarn("`treelike(T)` is deprecated, use `@treelike T`", :treelike)
|
||||||
|
treelike(Base._current_module(), T, fs)
|
||||||
|
end
|
||||||
|
|
||||||
|
macro treelike(T, fs = nothing)
|
||||||
|
fs == nothing || isexpr(fs, :tuple) || error("@treelike T (a, b)")
|
||||||
|
fs = fs == nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
|
||||||
|
:(treelike(@__MODULE__, $(esc(T)), $(fs...)))
|
||||||
|
end
|
||||||
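A hypothetical layer using the new macro, mirroring the `Affine` example from the Flux docs (a sketch; the struct and constructor below are not part of this patch):

```julia
using Flux

struct Affine
  W
  b
end

Affine(in::Integer, out::Integer) = Affine(param(randn(out, in)), param(zeros(out)))

(a::Affine)(x) = a.W * x .+ a.b

Flux.@treelike Affine   # children/mapchildren now see W and b, so params, gpu, etc. work
```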
|
|
||||||
isleaf(x) = isempty(children(x))
|
isleaf(x) = isempty(children(x))
|
||||||
|
|
||||||
function mapleaves(f, x; cache = ObjectIdDict())
|
function mapleaves(f, x; cache = IdDict())
|
||||||
haskey(cache, x) && return cache[x]
|
haskey(cache, x) && return cache[x]
|
||||||
cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x)
|
cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x)
|
||||||
end
|
end
|
||||||
@ -43,7 +54,7 @@ function loadparams!(m, xs)
|
|||||||
for (p, x) in zip(params(m), xs)
|
for (p, x) in zip(params(m), xs)
|
||||||
size(p) == size(x) ||
|
size(p) == size(x) ||
|
||||||
error("Expected param size $(size(p)), got $(size(x))")
|
error("Expected param size $(size(p)), got $(size(x))")
|
||||||
copy!(data(p), data(x))
|
copyto!(data(p), data(x))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -53,7 +64,7 @@ cpu(m) = mapleaves(x -> adapt(Array, x), m)
|
|||||||
|
|
||||||
gpu_adaptor = identity
|
gpu_adaptor = identity
|
||||||
|
|
||||||
@require CuArrays begin
|
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
|
||||||
global gpu_adaptor = CuArrays.cu
|
global gpu_adaptor = CuArrays.cu
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
# Arrays
|
# Arrays
|
||||||
|
|
||||||
initn(dims...) = randn(dims...)/100
|
initn(dims...) = randn(dims...)/100
|
||||||
glorot_uniform(dims...) = (rand(dims...) - 0.5)*sqrt(24.0/(sum(dims)))
|
glorot_uniform(dims...) = (rand(dims...) .- 0.5) .* sqrt(24.0/(sum(dims)))
|
||||||
glorot_normal(dims...) = (randn(dims...)*sqrt(2.0/sum(dims)))
|
glorot_normal(dims...) = randn(dims...) .* sqrt(2.0/sum(dims))
|
||||||
|
|
||||||
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
|
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
|
||||||
|
|
||||||
@ -119,7 +119,7 @@ function throttle(f, timeout; leading=true, trailing=false)
|
|||||||
end
|
end
|
||||||
|
|
||||||
cooldown = false
|
cooldown = false
|
||||||
@schedule try
|
@async try
|
||||||
while (sleep(timeout); later != nothing)
|
while (sleep(timeout); later != nothing)
|
||||||
later()
|
later()
|
||||||
later = nothing
|
later = nothing
|
||||||
@ -145,7 +145,7 @@ function jacobian(m,x)
|
|||||||
y = m(xp)
|
y = m(xp)
|
||||||
k = length(y)
|
k = length(y)
|
||||||
n = length(x)
|
n = length(x)
|
||||||
J = Matrix{eltype(x)}(n,k)
|
J = Matrix{eltype(x)}(undef,n,k)
|
||||||
for i = 1:k
|
for i = 1:k
|
||||||
Flux.back!(y[i]) # Populate gradient accumulator
|
Flux.back!(y[i]) # Populate gradient accumulator
|
||||||
J[:,i] = xp.grad
|
J[:,i] = xp.grad
|
||||||
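An illustrative call of `jacobian` as updated above (assumes it is reachable as `Flux.jacobian` and that a small `Dense` layer stands in for the model):

```julia
using Flux

m = Dense(3, 2)
x = rand(3)
J = Flux.jacobian(m, x)   # (length(x), length(y)) = 3×2; column i is ∂y[i]/∂x
size(J) == (3, 2)         # true
```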
|
@ -1,7 +1,7 @@
|
|||||||
using Flux, Flux.Tracker, CuArrays, Base.Test
|
using Flux, Flux.Tracker, CuArrays, Test
|
||||||
using Flux: gpu
|
using Flux: gpu
|
||||||
|
|
||||||
info("Testing Flux/GPU")
|
@info "Testing GPU Support"
|
||||||
|
|
||||||
@testset "CuArrays" begin
|
@testset "CuArrays" begin
|
||||||
|
|
||||||
@ -14,6 +14,7 @@ cx = gpu(x)
|
|||||||
x = Flux.onehotbatch([1, 2, 3], 1:3)
|
x = Flux.onehotbatch([1, 2, 3], 1:3)
|
||||||
cx = gpu(x)
|
cx = gpu(x)
|
||||||
@test cx isa Flux.OneHotMatrix && cx.data isa CuArray
|
@test cx isa Flux.OneHotMatrix && cx.data isa CuArray
|
||||||
|
@test (cx .+ 1) isa CuArray
|
||||||
|
|
||||||
m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
|
m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
|
||||||
cm = gpu(m)
|
cm = gpu(m)
|
||||||
@ -25,10 +26,13 @@ x = [1,2,3]
|
|||||||
cx = gpu(x)
|
cx = gpu(x)
|
||||||
@test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)
|
@test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)
|
||||||
|
|
||||||
# Fails in Pkg.test ffs
|
xs = param(rand(5,5))
|
||||||
# c = gpu(Conv((2,2),3=>4))
|
ys = Flux.onehotbatch(1:5,1:5)
|
||||||
# l = c(gpu(rand(10,10,3,2)))
|
@test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
|
||||||
# Flux.back!(sum(l))
|
|
||||||
|
c = gpu(Conv((2,2),3=>4))
|
||||||
|
l = c(gpu(rand(10,10,3,2)))
|
||||||
|
Flux.back!(sum(l))
|
||||||
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
using Flux, CuArrays, Base.Test
|
using Flux, CuArrays, Test
|
||||||
|
|
||||||
info("Testing Flux/CUDNN")
|
@info "Testing Flux/CUDNN"
|
||||||
|
|
||||||
@testset "RNN" begin
|
@testset "RNN" begin
|
||||||
@testset for R in [RNN, GRU, LSTM]
|
@testset for R in [RNN, GRU, LSTM]
|
||||||
|
@ -1,8 +1,13 @@
|
|||||||
using Flux.Data
|
using Flux.Data
|
||||||
using Base.Test
|
using Test
|
||||||
|
|
||||||
@test cmudict()["CATASTROPHE"] == :[K,AH0,T,AE1,S,T,R,AH0,F,IY0].args
|
@test cmudict()["CATASTROPHE"] == :[K,AH0,T,AE1,S,T,R,AH0,F,IY0].args
|
||||||
|
|
||||||
@test length(CMUDict.phones()) == 39
|
@test length(CMUDict.phones()) == 39
|
||||||
|
|
||||||
@test length(CMUDict.symbols()) == 84
|
@test length(CMUDict.symbols()) == 84
|
||||||
|
|
||||||
|
@test MNIST.images()[1] isa Matrix
|
||||||
|
@test MNIST.labels() isa Vector{Int64}
|
||||||
|
|
||||||
|
@test Data.Sentiment.train() isa Vector{Data.Tree{Any}}
|
||||||
|
23
test/layers/conv.jl
Normal file
23
test/layers/conv.jl
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
using Flux, Test
|
||||||
|
using Flux: maxpool, meanpool
|
||||||
|
|
||||||
|
@testset "Pooling" begin
|
||||||
|
x = randn(10, 10, 3, 2)
|
||||||
|
mp = MaxPool((2, 2))
|
||||||
|
@test mp(x) == maxpool(x, (2,2))
|
||||||
|
mp = MeanPool((2, 2))
|
||||||
|
@test mp(x) == meanpool(x, (2,2))
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset "CNN" begin
|
||||||
|
r = zeros(28, 28, 1, 5)
|
||||||
|
m = Chain(
|
||||||
|
Conv((2, 2), 1=>16, relu),
|
||||||
|
MaxPool((2,2)),
|
||||||
|
Conv((2, 2), 16=>8, relu),
|
||||||
|
MaxPool((2,2)),
|
||||||
|
x -> reshape(x, :, size(x, 4)),
|
||||||
|
Dense(288, 10), softmax)
|
||||||
|
|
||||||
|
@test size(m(r)) == (10, 5)
|
||||||
|
end
|
@ -4,7 +4,7 @@ using Flux: testmode!
|
|||||||
x = [1.,2.,3.]
|
x = [1.,2.,3.]
|
||||||
@test x == testmode!(Dropout(0.1))(x)
|
@test x == testmode!(Dropout(0.1))(x)
|
||||||
@test x == Dropout(0)(x)
|
@test x == Dropout(0)(x)
|
||||||
@test zeros(x) == Dropout(1)(x)
|
@test zero(x) == Dropout(1)(x)
|
||||||
|
|
||||||
x = rand(100)
|
x = rand(100)
|
||||||
m = Dropout(0.9)
|
m = Dropout(0.9)
|
||||||
@ -53,17 +53,17 @@ end
|
|||||||
# .1 * 4 + 0 = .4
|
# .1 * 4 + 0 = .4
|
||||||
@test m.μ ≈ reshape([0.3, 0.4], 2, 1)
|
@test m.μ ≈ reshape([0.3, 0.4], 2, 1)
|
||||||
|
|
||||||
# julia> .1 .* std(x, 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
|
# julia> .1 .* std(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
|
||||||
# 2×1 Array{Float64,2}:
|
# 2×1 Array{Float64,2}:
|
||||||
# 1.14495
|
# 1.14495
|
||||||
# 1.14495
|
# 1.14495
|
||||||
@test m.σ ≈ .1 .* std(x.data, 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
|
@test m.σ ≈ .1 .* std(x.data, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
|
||||||
|
|
||||||
testmode!(m)
|
testmode!(m)
|
||||||
@test !m.active
|
@test !m.active
|
||||||
|
|
||||||
x′ = m(x).data
|
x′ = m(x).data
|
||||||
@test x′[1] ≈ (1 - 0.3) / 1.1449489742783179
|
@test x′[1] ≈ (1 .- 0.3) / 1.1449489742783179
|
||||||
end
|
end
|
||||||
|
|
||||||
# with activation function
|
# with activation function
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
using Base.Test
|
using Test
|
||||||
using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
|
using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
|
||||||
σ, binarycrossentropy, logitbinarycrossentropy
|
σ, binarycrossentropy, logitbinarycrossentropy
|
||||||
|
|
||||||
@ -42,8 +42,8 @@ const ϵ = 1e-7
|
|||||||
|
|
||||||
logŷ, y = randn(3), rand(3)
|
logŷ, y = randn(3), rand(3)
|
||||||
@testset "binarycrossentropy" begin
|
@testset "binarycrossentropy" begin
|
||||||
@test binarycrossentropy.(σ.(logŷ), y; ϵ=0) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
|
@test binarycrossentropy.(σ.(logŷ), y; ϵ=0) ≈ -y.*log.(σ.(logŷ)) - (1 .- y).*log.(1 .- σ.(logŷ))
|
||||||
@test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 - y).*log.(1 - σ.(logŷ) .+ eps.(σ.(logŷ)))
|
@test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 .- y).*log.(1 .- σ.(logŷ) .+ eps.(σ.(logŷ)))
|
||||||
end
|
end
|
||||||
|
|
||||||
@testset "logitbinarycrossentropy" begin
|
@testset "logitbinarycrossentropy" begin
|
||||||
|
13
test/onehot.jl
Normal file
13
test/onehot.jl
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
using Flux:onecold
|
||||||
|
using Test
|
||||||
|
|
||||||
|
@testset "onecold" begin
|
||||||
|
a = [1, 2, 5, 3.]
|
||||||
|
A = [1 20 5; 2 7 6; 3 9 10; 2 1 14]
|
||||||
|
labels = ['A', 'B', 'C', 'D']
|
||||||
|
|
||||||
|
@test onecold(a) == 3
|
||||||
|
@test onecold(A) == [3, 1, 4]
|
||||||
|
@test onecold(a, labels) == 'C'
|
||||||
|
@test onecold(A, labels) == ['C', 'A', 'D']
|
||||||
|
end
|
@ -1,6 +1,6 @@
|
|||||||
using Flux.Optimise
|
using Flux.Optimise
|
||||||
using Flux.Tracker
|
using Flux.Tracker
|
||||||
|
using Test
|
||||||
@testset "Optimise" begin
|
@testset "Optimise" begin
|
||||||
w = randn(10, 10)
|
w = randn(10, 10)
|
||||||
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
|
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
|
||||||
@ -23,7 +23,7 @@ end
|
|||||||
Flux.train!(() -> (sleep(0.1); i += 1; l),
|
Flux.train!(() -> (sleep(0.1); i += 1; l),
|
||||||
Iterators.repeated((), 100),
|
Iterators.repeated((), 100),
|
||||||
()->(),
|
()->(),
|
||||||
cb = Flux.throttle(() -> (i > 3 && :stop), 1))
|
cb = Flux.throttle(() -> (i > 3 && stop()), 1))
|
||||||
|
|
||||||
@test 3 < i < 50
|
@test 3 < i < 50
|
||||||
end
|
end
|
||||||
|
@ -1,17 +1,46 @@
|
|||||||
using Flux, Base.Test
|
# Pkg.test runs with --check_bounds=1, forcing all bounds checks.
|
||||||
|
# This is incompatible with CUDAnative (see JuliaGPU/CUDAnative.jl#98)
|
||||||
|
if Base.JLOptions().check_bounds == 1
|
||||||
|
file = @__FILE__
|
||||||
|
run(```
|
||||||
|
$(Base.julia_cmd())
|
||||||
|
--color=$(Base.have_color ? "yes" : "no")
|
||||||
|
--compiled-modules=$(Bool(Base.JLOptions().use_compiled_modules) ? "yes" : "no")
|
||||||
|
--startup-file=$(Base.JLOptions().startupfile != 2 ? "yes" : "no")
|
||||||
|
--code-coverage=$(["none", "user", "all"][1+Base.JLOptions().code_coverage])
|
||||||
|
$(file)
|
||||||
|
```)
|
||||||
|
exit()
|
||||||
|
end
|
||||||
|
|
||||||
srand(0)
|
using Flux, Test, Random
|
||||||
|
using Random
|
||||||
|
|
||||||
|
Random.seed!(0)
|
||||||
|
|
||||||
|
# So we can use the system CuArrays
|
||||||
|
insert!(LOAD_PATH, 2, "@v#.#")
|
||||||
|
|
||||||
@testset "Flux" begin
|
@testset "Flux" begin
|
||||||
|
|
||||||
|
@info "Testing Basics"
|
||||||
|
|
||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
include("tracker.jl")
|
include("onehot.jl")
|
||||||
include("layers/normalisation.jl")
|
|
||||||
include("layers/stateless.jl")
|
|
||||||
include("optimise.jl")
|
include("optimise.jl")
|
||||||
include("data.jl")
|
include("data.jl")
|
||||||
|
|
||||||
if Base.find_in_path("CuArrays") ≠ nothing
|
@info "Testing Layers"
|
||||||
|
|
||||||
|
include("layers/normalisation.jl")
|
||||||
|
include("layers/stateless.jl")
|
||||||
|
include("layers/conv.jl")
|
||||||
|
|
||||||
|
@info "Running Gradient Checks"
|
||||||
|
|
||||||
|
include("tracker.jl")
|
||||||
|
|
||||||
|
if Base.find_package("CuArrays") != nothing
|
||||||
include("cuda/cuda.jl")
|
include("cuda/cuda.jl")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
127
test/tracker.jl
127
test/tracker.jl
@ -1,22 +1,27 @@
|
|||||||
using Flux.Tracker, Base.Test, NNlib
|
using Flux
|
||||||
|
using Flux.Tracker, Test, NNlib
|
||||||
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
|
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
|
||||||
using NNlib: conv
|
using NNlib: conv
|
||||||
|
using Printf: @sprintf
|
||||||
|
using LinearAlgebra: Diagonal, dot, LowerTriangular, norm
|
||||||
|
using Statistics: mean, std
|
||||||
|
using Random
|
||||||
|
# using StatsBase
|
||||||
|
|
||||||
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
|
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
|
||||||
gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
|
||||||
|
|
||||||
@testset "Tracker" begin
|
@testset "Tracker" begin
|
||||||
|
|
||||||
@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
|
@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
|
||||||
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
|
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
|
||||||
@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
|
@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
|
||||||
@test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)
|
@test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)
|
||||||
|
@test gradtest((w, x) -> w'*x, randn(Float64,10, 2), randn(Float64,10))
|
||||||
@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
|
@test gradtest((w, x) -> w*x', randn(Float64,5,5), randn(Float64,5,5))
|
||||||
@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))
|
@test gradtest(x -> sum(x, dims = (2, 3)), (3,4,5))
|
||||||
|
@test gradtest(x -> sum(x, dims = 1), randn(Float64,2,3))
|
||||||
@test gradtest(x -> sum(x, (2, 3)), (3,4,5))
|
@test gradtest(x -> sum(x, dims = [1,2]), randn(Float64,2,3))
|
||||||
@test gradtest(x -> prod(x, (2, 3)), (3,4,5))
|
@test gradtest(x -> sum(x), randn(Float64,2,3))
|
||||||
|
@test gradtest(x -> prod(x, dims=(2, 3)), (3,4,5))
|
||||||
@test gradtest(x -> prod(x), (3,4,5))
|
@test gradtest(x -> prod(x), (3,4,5))
|
||||||
|
|
||||||
@test gradtest(x -> softmax(x).*(1:3), 3)
|
@test gradtest(x -> softmax(x).*(1:3), 3)
|
||||||
@ -28,7 +33,6 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||||||
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))
|
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))
|
||||||
|
|
||||||
@test gradtest(x -> x', rand(5))
|
@test gradtest(x -> x', rand(5))
|
||||||
|
|
||||||
function promotiontest(f, A, B, C)
|
function promotiontest(f, A, B, C)
|
||||||
r0 = f(A, B, C)
|
r0 = f(A, B, C)
|
||||||
r1 = f(param(A), B, C)
|
r1 = f(param(A), B, C)
|
||||||
@ -48,8 +52,8 @@ function promotiontest(f, A, B, C)
|
|||||||
end
|
end
|
||||||
|
|
||||||
@testset "concat" begin
|
@testset "concat" begin
|
||||||
cat1(x...) = cat(1, x...)
|
cat1(x...) = cat(x..., dims = 1)
|
||||||
cat2(x...) = cat(2, x...)
|
cat2(x...) = cat(x..., dims = 2)
|
||||||
|
|
||||||
@testset for vcatf in [vcat, cat1]
|
@testset for vcatf in [vcat, cat1]
|
||||||
@test gradtest(vcatf, rand(5), rand(3))
|
@test gradtest(vcatf, rand(5), rand(3))
|
||||||
@ -61,6 +65,7 @@ end
|
|||||||
@test gradtest(vcatf, rand(5)', rand(2,5))
|
@test gradtest(vcatf, rand(5)', rand(2,5))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
@testset for hcatf in [hcat, cat2]
|
@testset for hcatf in [hcat, cat2]
|
||||||
@test gradtest(hcatf, rand(5), rand(5))
|
@test gradtest(hcatf, rand(5), rand(5))
|
||||||
@test gradtest(hcatf, rand(5)', rand(5)')
|
@test gradtest(hcatf, rand(5)', rand(5)')
|
||||||
@ -71,17 +76,17 @@ end
|
|||||||
@test gradtest(hcatf, rand(5), rand(5,2))
|
@test gradtest(hcatf, rand(5), rand(5,2))
|
||||||
end
|
end
|
||||||
|
|
||||||
@testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
|
@testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))]
|
||||||
@test gradtest(catf, rand(5))
|
@test gradtest(catf, rand(5))
|
||||||
@test gradtest(catf, rand(5)')
|
@test gradtest(catf, rand(5)')
|
||||||
@test gradtest(catf, rand(2,5))
|
@test gradtest(catf, rand(2,5))
|
||||||
@test gradtest(catf, rand(2,5,3))
|
@test gradtest(catf, rand(2,5,3))
|
||||||
end
|
end
|
||||||
|
|
||||||
@test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4))
|
@test gradtest((x...) -> cat(x..., dims = 3), rand(2,5,2), rand(2,5,3), rand(2,5,4))
|
||||||
|
|
||||||
@testset "cat($dim, ...)" for dim in 3:5
|
@testset "cat($dim, ...)" for dim in 3:5
|
||||||
catdim = (x...) -> cat(dim, x...)
|
catdim = (x...) -> cat(x..., dims = dim)
|
||||||
@test gradtest(catdim, rand(5), rand(5), rand(5))
|
@test gradtest(catdim, rand(5), rand(5), rand(5))
|
||||||
@test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5))
|
@test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5))
|
||||||
@test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3))
|
@test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3))
|
||||||
@@ -89,12 +94,12 @@ end

@test !isa(vcat(rand(2)), TrackedArray)
@test !isa(hcat(rand(2)), TrackedArray)
- @test !isa(cat(1,rand(2)), TrackedArray)
+ @test !isa(cat(rand(2), dims=1), TrackedArray)

- @test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1))
+ @test gradtest((a,b)->cat(a, b, dims = (2,3,5)), rand(2,3), rand(2,4,2,1))

@testset "promotiontest" begin
- @testset for fcat in [hcat, vcat, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
+ @testset for fcat in [hcat, vcat, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))]
promotiontest(fcat, rand(2), rand(2), rand(2))
promotiontest(fcat, rand(2)', rand(2)', rand(2)')
promotiontest(fcat, rand(2,2), rand(2,2), rand(2,2))
@@ -105,16 +110,14 @@ end
promotiontest(hcat, rand(2,1), rand(2), rand(2,2))
promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5))
promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5))
- promotiontest((x...) -> cat(3, x...), rand(4,5,3), rand(4,5,1), rand(4,5,2))
+ promotiontest((x...) -> cat(x..., dims = 3), rand(4,5,3), rand(4,5,1), rand(4,5,2))
end

end

@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))

- # TODO unreliable
- @test gradtest(x -> repmat(x, 5,5), rand(4,5))
- @test gradtest(x -> repmat(x, 5), rand(4,5))
+ @test gradtest(x -> repeat(x; inner=2), rand(5))

@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
@test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))
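The removed `repmat` calls reflect another 0.7 rename: `repmat` is gone and `repeat` covers both the positional and the `inner`/`outer` keyword forms. A small illustration, assuming only plain `Vector`/`Matrix` values and nothing Flux-specific:

```julia
x = [1, 2, 3]

repeat(x, 2)                      # [1, 2, 3, 1, 2, 3] — what repmat(x, 2) used to do
repeat(x, inner = 2)              # [1, 1, 2, 2, 3, 3]
repeat(x, inner = 2, outer = 3)   # each element doubled, then the whole vector tiled three times

A = [1 2; 3 4]
repeat(A, 2, 3)                   # 4×6 tiling, the old repmat(A, 2, 3)
```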
@@ -124,54 +127,59 @@ end
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))

- @test gradtest(diagm, rand(3))
+ @test gradtest(f-> Matrix(Diagonal(f)), rand(3))

+ @test gradtest(W -> inv(log.(W * W)), (5,5))
+ @test gradtest((A, B) -> A / B , (1,5), (5,5))
+ @test gradtest((A, B) -> log.(A * A) / exp.(B * B), (5,5), (5,5))
+ @test gradtest((A, B) -> log.(A * A) \ exp.(B * B), (5,5), (5,5))

@testset "mean" begin
@test gradtest(mean, rand(2, 3))

- @test gradtest(x -> mean(x, 1), rand(2, 3))
+ @test gradtest(x -> mean(x, dims=1), rand(2, 3))
- @test gradtest(x -> mean(x, 2), rand(2, 3))
+ @test gradtest(x -> mean(x, dims=2), rand(2, 3))
- @test gradtest(x -> mean(x, 3), rand(2, 3, 4))
+ @test gradtest(x -> mean(x, dims=3), rand(2, 3, 4))

- @test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4))
+ @test gradtest(x -> mean(x, dims=[1, 2]), rand(2, 3, 4))
end

@testset "maximum" begin
@test gradtest(maximum, rand(2, 3))

- @test gradtest(x -> maximum(x, 1), rand(2, 3))
+ @test gradtest(x -> maximum(x, dims=1), rand(2, 3))
- @test gradtest(x -> maximum(x, 2), rand(2, 3))
+ @test gradtest(x -> maximum(x, dims=2), rand(2, 3))
- @test gradtest(x -> maximum(x, 3), rand(2, 3, 4))
+ @test gradtest(x -> maximum(x, dims=3), rand(2, 3, 4))

- @test gradtest(x -> maximum(x, [1, 2]), rand(2, 3, 4))
+ @test gradtest(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4))
end

@testset "minimum" begin
@test gradtest(minimum, rand(2, 3))

- @test gradtest(x -> minimum(x, 1), rand(2, 3))
+ @test gradtest(x -> minimum(x, dims=1), rand(2, 3))
- @test gradtest(x -> minimum(x, 2), rand(2, 3))
+ @test gradtest(x -> minimum(x, dims=2), rand(2, 3))
- @test gradtest(x -> minimum(x, 3), rand(2, 3, 4))
+ @test gradtest(x -> minimum(x, dims=3), rand(2, 3, 4))

- @test gradtest(x -> minimum(x, [1, 2]), rand(2, 3, 4))
+ @test gradtest(x -> minimum(x, dims=[1, 2]), rand(2, 3, 4))
end

@test gradtest(x -> std(x), rand(5,5))
- @test gradtest(x -> std(x, 1), rand(5,5))
+ @test gradtest(x -> std(x, dims = 1), rand(5,5))

@test gradtest((x, y) -> x .* y, rand(5), rand(5))
@test gradtest(dot, rand(5), rand(5))

- @test gradtest(vecnorm, rand(5))
+ @test gradtest(norm, rand(5))

@test gradtest(rand(5)) do x
y = x.^2
2y + x
end

- @test gradtest(conv, rand(10, 3, 2), randn(2, 3, 2))
+ @test gradtest(conv, rand(10, 3, 2), randn(Float64,2, 3, 2))
- @test gradtest(conv, rand(10, 10, 3, 2), randn(2, 2, 3, 2))
+ @test gradtest(conv, rand(10, 10, 3, 2), randn(Float64,2, 2, 3, 2))
- @test gradtest(conv, rand(10, 10, 10, 3, 2), randn(2, 2, 2, 3, 2))
+ @test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64,2, 2, 2, 3, 2))

@test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
@test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))
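Likewise, reductions now take a `dims` keyword instead of a positional dimension, and `mean`/`std` moved out of Base into the `Statistics` standard library, while `vecnorm` became `norm` in `LinearAlgebra`. A sketch of the bare 0.7/1.0 API, assuming nothing beyond the standard library:

```julia
using Statistics      # mean, std no longer live in Base
using LinearAlgebra   # norm replaces vecnorm for vectors

A = rand(2, 3)

mean(A)               # scalar mean of all entries
mean(A, dims = 1)     # 1×3 row of column means (was mean(A, 1))
maximum(A, dims = 2)  # 2×1 column of row maxima (was maximum(A, 2))
std(A, dims = 1)      # column standard deviations

norm([3.0, 4.0])      # 5.0 — the old vecnorm
```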
@@ -179,9 +187,30 @@ end
@test gradtest(x -> meanpool(x, (2,2)), rand(10, 10, 3, 2))
@test gradtest(x -> meanpool(x, (2,2,2)), rand(5, 5, 5, 3, 2))

- @test (param([1,2,3]) .< 2) == [true, false, false]
- @test param(2)^2 == 4.0
+ @testset "equality & order" begin
+ # TrackedReal
+ @test param(2)^2 == param(4)
+ @test param(2)^2 == 4
+ @test 4 == param(2)^2
+
+ @test param(2)^2 ≈ param(4)
+ @test param(2)^2 ≈ 4
+ @test 4 ≈ param(2)^2
+
+ @test (param([1,2,3]) .< 2) == [true, false, false]
+ @test (param([1,2,3]) .<= 2) == [true, true, false]
+ @test (2 .> param([1,2,3])) == [true, false, false]
+ @test (2 .>= param([1,2,3])) == [true, true, false]
+
+ # TrackedArray
+ @test param([1,2,3]).^2 == param([1,4,9])
+ @test [1,2,3].^2 == param([1,4,9])
+ @test param([1,2,3]).^2 == [1,4,9]
+
+ @test param([1,2,3]).^2 ≈ param([1,4,9])
+ @test [1,2,3].^2 ≈ param([1,4,9])
+ @test param([1,2,3]).^2 ≈ [1,4,9]
+ end

@testset "reshape" begin
x = reshape(param(rand(2,2,2)), 4, 2)
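The new "equality & order" tests exercise comparisons between tracked and untracked values. A short sketch of the idea, assuming Flux's `param` and `Tracker.data` as used elsewhere in this test suite:

```julia
using Flux
using Flux.Tracker: data

x = param([1, 2, 3])                # TrackedArray wrapping a Float64 vector

x.^2 == [1, 4, 9]                   # true — comparisons fall through to the wrapped values
(x .< 2) == [true, false, false]    # true
data(x.^2)                          # the underlying Array, with tracking stripped
```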
@@ -211,14 +240,11 @@ end
@testset "Fallbacks" begin
xs = param([1 2; 3 4])
@test similar(xs) isa Matrix{Float64}
- # Remove this test if we do LowerTriangular properly
- L = LowerTriangular(xs)
- @test L*L' isa Matrix{TrackedReal{Float64}}
end

@test @sprintf("%.2f", sum(param([1,2,3]))) == "6.00"

- @inferred NNlib.conv(param(rand(10,10,3,2)),randn(2,2,3,4))
+ @inferred NNlib.conv(param(rand(10,10,3,2)),randn(Float64,2,2,3,4))

b = param(rand())
Tracker.back!(b)
@@ -231,6 +257,11 @@ Tracker.back!(b)
z = xy[1]*xy[2]
back!(z)
@test grad.((x,y)) == (3, 2)

+ @test Tracker.gradient(2, 3) do x, y
+ xy = Tracker.collect([x, y])
+ xy[1]*xy[2]
+ end == (3, 2)
end

# Gradient Hooks
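The added test uses `Tracker.gradient`, which evaluates a function and returns the gradient with respect to each argument. A minimal sketch mirroring the test above, with the function passed directly instead of via a `do` block:

```julia
using Flux.Tracker

# For f(x, y) = x*y at (2, 3): ∂f/∂x = y = 3 and ∂f/∂y = x = 2.
Tracker.gradient((x, y) -> x * y, 2, 3) == (3, 2)   # true, as asserted in the test
```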
@@ -1,9 +1,13 @@
- using Flux: throttle, initn, glorot_uniform, glorot_normal, jacobian
+ using Flux
+ using Flux: throttle, jacobian, initn, glorot_uniform, glorot_normal
+ using StatsBase: std
+ using Random
+ using Test

@testset "Throttle" begin
@testset "default behaviour" begin
a = []
- f = throttle(()->push!(a, now()), 1, leading=true, trailing=false)
+ f = throttle(()->push!(a, time()), 1, leading=true, trailing=false)
f()
f()
f()
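The behavioural change in these hunks is only the timestamp: `now()` lives in the Dates stdlib, while `time()` is in Base and returns seconds since the epoch as a `Float64`, which is all the throttle tests need. A rough sketch of how `throttle` is exercised, using the same signature as the tests above:

```julia
using Flux: throttle

calls = []
f = throttle(() -> push!(calls, time()), 1, leading = true, trailing = false)

f(); f(); f()            # invoked three times within the one-second window...
length(calls) == 1       # ...but the wrapped function only ran on the leading call
```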
@@ -13,7 +17,7 @@ using Flux: throttle, initn, glorot_uniform, glorot_normal, jacobian

@testset "leading behaviour" begin
a = []
- f = throttle(()->push!(a, now()), 1, leading=true, trailing=false)
+ f = throttle(()->push!(a, time()), 1, leading=true, trailing=false)
f()
@test length(a) == 1
f()
@@ -25,7 +29,7 @@ using Flux: throttle, initn, glorot_uniform, glorot_normal, jacobian

@testset "trailing behaviour" begin
a = []
- f = throttle(()->push!(a, now()), 1, leading=false, trailing=true)
+ f = throttle(()->push!(a, time()), 1, leading=false, trailing=true)
f()
@test length(a) == 0
f()
@@ -59,7 +63,7 @@ end

@testset "Initialization" begin
# Set random seed so that these tests don't fail randomly
- srand(0)
+ Random.seed!(0)
# initn() should yield a kernel with stddev ~= 1e-2
v = initn(10, 10)
@test std(v) > 0.9*1e-2
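`srand` no longer exists in Julia 0.7/1.0; seeding the global RNG goes through the `Random` stdlib instead, which is why the file now does `using Random`. A minimal sketch:

```julia
using Random

Random.seed!(0)        # replaces srand(0)
a = rand(3)

Random.seed!(0)
b = rand(3)

a == b                 # true — same seed, same random stream
```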