Merge branch 'master' into nadam-opt
This commit is contained in:
commit
4a24b69976
26
README.md
26
README.md
|
@ -1,6 +1,8 @@
|
|||
# Флукс
|
||||
<p align="center">
|
||||
<img width="400px" src="https://raw.githubusercontent.com/FluxML/fluxml.github.io/master/logo.png"/>
|
||||
</p>
|
||||
|
||||
[](https://travis-ci.org/FluxML/Flux.jl) [](https://fluxml.github.io/Flux.jl/stable/) [](https://slackinvite.julialang.org/)
|
||||
[](https://travis-ci.org/FluxML/Flux.jl) [](https://fluxml.github.io/Flux.jl/stable/) [](https://slackinvite.julialang.org/) [](https://doi.org/10.21105/joss.00602)
|
||||
|
||||
Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, and provides lightweight abstractions on top of Julia's native GPU and AD support. Flux makes the easy things easy while remaining fully hackable.
|
||||
|
||||
|
@ -10,6 +12,18 @@ julia> Pkg.add("Flux")
|
|||
|
||||
See the [documentation](http://fluxml.github.io/Flux.jl/) or the [model zoo](https://github.com/FluxML/model-zoo/) for examples.
|
||||
|
||||
If you use Flux in research, please cite the following paper:
|
||||
|
||||
```
|
||||
@article{innes:2018,
|
||||
author = {Mike Innes},
|
||||
title = {Flux: Elegant Machine Learning with Julia},
|
||||
journal = {Journal of Open Source Software},
|
||||
year = {2018},
|
||||
doi = {10.21105/joss.00602},
|
||||
}
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
Flux has powerful high-level features, and common architectures can be defined in a few lines.
|
||||
|
@ -36,7 +50,7 @@ b = param(randn(2))
|
|||
y(x) = σ.(W * x .+ b)
|
||||
```
|
||||
|
||||
If that's *still* not enough, you can go as deep as you want, even writing your own CUDA kernels! All this can be freely mixed-and-matched in a single model or script, and it all runs interactively via Jupyter or Juno.
|
||||
If that's *still* not enough, you can go as deep as you want, even writing your own CUDA kernels with [CUDAnative](https://github.com/JuliaGPU/CUDAnative.jl)! All this can be freely mixed-and-matched in a single model or script, and it all runs interactively via Jupyter or Juno.
|
||||
|
||||
```julia
|
||||
function gpu_add(a, b, c)
|
||||
|
@ -77,3 +91,9 @@ For general questions and help, check out Julia's [community forum](https://disc
|
|||
Flux development is carried out via our [GitHub issues](https://github.com/FluxML/Flux.jl/issues), so feel free to open feature requests or PRs here.
|
||||
|
||||
For more informal discussions we'd love to have you on the [Julia slack](https://slackinvite.julialang.org/), where we hang out on the #machine-learning channel.
|
||||
|
||||
## Related Packages
|
||||
|
||||
Check out [Metalhead.jl](https://github.com/FluxML/Metalhead.jl) for common computer vision datasets and trained models.
|
||||
|
||||
[MLDatasets.jl](https://github.com/JuliaML/MLDatasets.jl) provides further common datasets.
|
||||
|
|
|
@ -18,6 +18,8 @@ makedocs(modules=[Flux, NNlib],
|
|||
"One-Hot Encoding" => "data/onehot.md",
|
||||
"GPU Support" => "gpu.md",
|
||||
"Saving & Loading" => "saving.md",
|
||||
"Internals" =>
|
||||
["Backpropagation" => "internals/tracker.md"],
|
||||
"Community" => "community.md"])
|
||||
|
||||
deploydocs(
|
||||
|
|
|
@ -0,0 +1,156 @@
|
|||
# Flux.Tracker
|
||||
|
||||
Backpropagation, or reverse-mode automatic differentiation, is handled by the `Flux.Tracker` module.
|
||||
|
||||
```julia
|
||||
julia> using Flux.Tracker
|
||||
```
|
||||
|
||||
The `param` function converts a normal Julia array into a new object that, while behaving like an array, tracks extra information that allows us to calculate derivatives. For example, say we multiply two parameters:
|
||||
|
||||
```julia
|
||||
julia> W = param([1 2; 3 4])
|
||||
Tracked 2×2 Array{Float64,2}:
|
||||
1.0 2.0
|
||||
3.0 4.0
|
||||
|
||||
julia> x = param([5, 6])
|
||||
Tracked 2-element Array{Float64,1}:
|
||||
5.0
|
||||
6.0
|
||||
|
||||
julia> y = W*x
|
||||
Tracked 2-element Array{Float64,1}:
|
||||
17.0
|
||||
39.0
|
||||
```
|
||||
|
||||
The output `y` is also a `TrackedArray` object. We can now backpropagate sensitivities to `W` and `x` via the `back!` function, and see the gradients accumulated in the `W` and `x` tracked arrays:
|
||||
|
||||
```julia
|
||||
julia> Tracker.back!(y, [1, -1])
|
||||
|
||||
julia> W.grad
|
||||
2×2 Array{Float64,2}:
|
||||
5.0 6.0
|
||||
-5.0 -6.0
|
||||
|
||||
julia> x.grad
|
||||
2-element Array{Float64,1}:
|
||||
-2.0
|
||||
-2.0
|
||||
```
|
||||
|
||||
## Internals
|
||||
|
||||
All `Tracked*` objects (`TrackedArray`, `TrackedReal`) are light wrappers around the `Tracked` type, which you can access via the `.tracker` field.
|
||||
|
||||
```julia
|
||||
julia> x.tracker
|
||||
Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Void,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0])
|
||||
```
|
||||
|
||||
The `Tracker` stores the value and gradient of a given object, which we've seen before.
|
||||
|
||||
```julia
|
||||
julia> x.tracker.data
|
||||
2-element Array{Float64,1}:
|
||||
5.0
|
||||
6.0
|
||||
|
||||
julia> x.tracker.grad
|
||||
2-element Array{Float64,1}:
|
||||
-2.0
|
||||
-2.0
|
||||
```
|
||||
|
||||
The tracker also contains a `Call` object, which simply represents a function call that was made at some point during the forward pass. For example, the `+` call would look like this:
|
||||
|
||||
```julia
|
||||
julia> Tracker.Call(+, 1, 2)
|
||||
Flux.Tracker.Call{Base.#+,Tuple{Int64,Int64}}(+, (1, 2))
|
||||
```
|
||||
|
||||
In the case of the `y` we produced above, we can see that it stores the call that produced it -- that is, `W*x`.
|
||||
|
||||
```julia
|
||||
julia> y.tracker.f
|
||||
Flux.Tracker.Call{...}(*, (param([1.0 2.0; 3.0 4.0]), param([5.0, 6.0])))
|
||||
```
|
||||
|
||||
Notice that because the arguments to the call may also be tracked arrays, storing their own calls, this means that `Tracker` ends up forming a data structure that records everything that happened during the forward pass (often known as a *tape*).
|
||||
|
||||
When we call `back!(y, [1, -1])`, the sensitivities `[1, -1]` simply get forwarded to `y`'s call (`*`), effectively calling
|
||||
|
||||
```julia
|
||||
Tracker.back(*, [1, -1], W, x)
|
||||
```
|
||||
|
||||
which in turn calculates the sensitivities of the arguments (`W` and `x`) and backpropagates through their calls. This is recursive, so it will walk the entire program graph and propagate gradients to the original model parameters.
|
||||
|
||||
## Custom Gradients
|
||||
|
||||
We can hook in to the processes above to implement custom gradients for a function or kernel. For a toy example, imagine a custom implementation of `minus`:
|
||||
|
||||
```julia
|
||||
julia> minus(a, b) = a - b
|
||||
```
|
||||
|
||||
Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch:
|
||||
|
||||
```julia
|
||||
julia> minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b)
|
||||
minus (generic function with 2 methods)
|
||||
```
|
||||
|
||||
`Tracker.track` does two things: (1) it makes sure `minus` is called with *normal* array, not tracked ones (you can use `@show` inside `minus` to verify this), and (2) it uses the result to add a `minus` node to the tape. Look inside the result of calling `minus` to see what happened:
|
||||
|
||||
```julia
|
||||
julia> a, b = param([6,5,4]), param([1,2,3])
|
||||
(param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0]))
|
||||
|
||||
julia> c = minus(a, b)
|
||||
Tracked 3-element Array{Float64,1}:
|
||||
5.0
|
||||
3.0
|
||||
1.0
|
||||
|
||||
julia> c.tracker.f
|
||||
Flux.Tracker.Call{...}(minus, (param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0])))
|
||||
```
|
||||
|
||||
Finally, we have to specify the gradient of `minus`.
|
||||
|
||||
```julia
|
||||
julia> Tracker.back(::typeof(minus), Δ, a, b) =
|
||||
(Tracker.@back(a, Δ); Tracker.@back(b, -Δ))
|
||||
```
|
||||
|
||||
`@back(x, Δ)` tells the tracker to continue propagating the sensitivity `Δ` through `x`. Now, AD will work with any program that calls `minus`.
|
||||
|
||||
```julia
|
||||
julia> Flux.back!(c, 1)
|
||||
|
||||
julia> a.grad
|
||||
3-element Array{Float64,1}:
|
||||
1.0
|
||||
1.0
|
||||
1.0
|
||||
|
||||
julia> b.grad
|
||||
3-element Array{Float64,1}:
|
||||
-1.0
|
||||
-1.0
|
||||
-1.0
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, TrackedArray)` and so on. To do so, just define those extra signatures as needed:
|
||||
|
||||
```julia
|
||||
minus(a::AbstractArray, b::TrackedArray) = Tracker.track(minus, a, b)
|
||||
minus(a::TrackedArray, b::AbstractArray) = Tracker.track(minus, a, b)
|
||||
```
|
||||
|
||||
`@back` *must* be called exactly once on each tracked input argument. You do not need to do any special handling if one of the arguments is not tracked, as `@back` will just become a no-op.
|
|
@ -5,7 +5,7 @@ These core layers form the foundation of almost all neural networks.
|
|||
```@docs
|
||||
Chain
|
||||
Dense
|
||||
Conv2D
|
||||
Conv
|
||||
```
|
||||
|
||||
## Recurrent Layers
|
||||
|
@ -15,6 +15,7 @@ Much like the core layers above, but can be used to process sequence data (as we
|
|||
```@docs
|
||||
RNN
|
||||
LSTM
|
||||
GRU
|
||||
Flux.Recur
|
||||
```
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ add the result to the overall loss.
|
|||
For example, say we have a simple regression.
|
||||
|
||||
```julia
|
||||
using Flux: crossentropy
|
||||
m = Dense(10, 5)
|
||||
loss(x, y) = crossentropy(softmax(m(x)), y)
|
||||
```
|
||||
|
|
|
@ -17,14 +17,6 @@
|
|||
url = {http://arxiv.org/abs/1712.03112},
|
||||
}
|
||||
|
||||
@online{CuArrays,
|
||||
author = {Mike Innes},
|
||||
title = {Generic GPU Kernels},
|
||||
year = 2017,
|
||||
url = {http://mikeinnes.github.io/2017/08/24/cudanative.html},
|
||||
urldate = {2018-02-16}
|
||||
}
|
||||
|
||||
@online{MLPL,
|
||||
author = {Mike Innes and others},
|
||||
title = {On Machine Learning and Programming Languages},
|
||||
|
@ -33,10 +25,26 @@
|
|||
urldate = {2018-02-16}
|
||||
}
|
||||
|
||||
@online{Fusion,
|
||||
author = {Steven G. Johnson},
|
||||
title = {More Dots: Syntactic Loop Fusion in Julia},
|
||||
@online{CuArrays,
|
||||
author = {Mike Innes and others},
|
||||
title = {Generic GPU Kernels},
|
||||
year = 2017,
|
||||
url = {https://julialang.org/blog/2017/01/moredots},
|
||||
url = {http://mikeinnes.github.io/2017/08/24/cudanative.html},
|
||||
urldate = {2018-02-16}
|
||||
}
|
||||
|
||||
@online{Zoo,
|
||||
author = {Mike Innes and others},
|
||||
title = {Flux Model Zoo},
|
||||
year = 2018,
|
||||
url = {https://github.com/FluxML/model-zoo/},
|
||||
urldate = {2018-02-16}
|
||||
}
|
||||
|
||||
@online{Minibatch,
|
||||
author = {James Bradbury},
|
||||
title = {Minibatch.jl},
|
||||
year = 2018,
|
||||
url = {https://github.com/jekbradbury/Minibatch.jl},
|
||||
urldate = {2018-02-16}
|
||||
}
|
||||
|
|
|
@ -24,8 +24,8 @@ bibliography: paper.bib
|
|||
|
||||
Flux is library for machine learning (ML), written using the numerical computing language Julia [@Julia]. The package allows models to be written using Julia's simple mathematical syntax, and applies automatic differentiation (AD) to seamlessly calculate derivatives and train the model. Meanwhile, it makes heavy use of Julia's language and compiler features to carry out code analysis and make optimisations. For example, Julia's GPU compilation support [@besard:2017] can be used to JIT-compile custom GPU kernels for model layers [@CuArrays].
|
||||
|
||||
The machine learning community has traditionally been divided between "static" and "dynamic" frameworks that are easy to optimise and easy to use, respectively [@MLPL]. Flux blurs the line between these two approaches, combining a highly intuitive programming model with the compiler techniques needed by ML. As a result of this approach, it already supports several features not available in any other dynamic framework, such as kernel fusion [@Fusion], memory usage optimisations, importing of models via ONNX, and deployment of models to JavaScript for running in the browser.
|
||||
The machine learning community has traditionally been divided between "static" and "dynamic" frameworks that are easy to optimise and easy to use, respectively [@MLPL]. Flux blurs the line between these two approaches, combining a highly intuitive programming model with the compiler techniques needed by ML. This enables research into advanced compiler transforms such as batching [@Minibatch] without changing any user code.
|
||||
|
||||
Flux has been used heavily for natural language processing, but can also support state-of-the-art research models in areas like computer vision, reinforcement learning and robotics.
|
||||
Flux has been used heavily for natural language processing, but can also support state-of-the-art research models in areas like computer vision, reinforcement learning and robotics. Many examples of such models can be found in the model zoo [@Zoo].
|
||||
|
||||
# References
|
||||
|
|
17
src/Flux.jl
17
src/Flux.jl
|
@ -7,22 +7,23 @@ module Flux
|
|||
using Juno, Requires, Reexport
|
||||
using MacroTools: @forward
|
||||
|
||||
export Chain, Dense, RNN, LSTM, GRU, Conv, Conv2D,
|
||||
Dropout, LayerNorm, BatchNorm,
|
||||
SGD, ADAM, Momentum, Nesterov, AMSGrad, NADAM,
|
||||
param, params, mapleaves, cpu, gpu
|
||||
export Chain, Dense, RNN, LSTM, GRU, Conv,
|
||||
Dropout, LayerNorm, BatchNorm,
|
||||
params, mapleaves, cpu, gpu
|
||||
|
||||
@reexport using NNlib
|
||||
using NNlib: @fix
|
||||
|
||||
include("tracker/Tracker.jl")
|
||||
using .Tracker
|
||||
export Tracker
|
||||
import .Tracker: data
|
||||
using .Tracker: data
|
||||
export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
|
||||
|
||||
include("optimise/Optimise.jl")
|
||||
using .Optimise
|
||||
using .Optimise: @epochs
|
||||
export SGD, ADAM, AdaMax, Momentum, Nesterov,
|
||||
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
|
||||
|
||||
include("utils.jl")
|
||||
include("onehot.jl")
|
||||
|
@ -32,12 +33,10 @@ include("layers/stateless.jl")
|
|||
include("layers/basic.jl")
|
||||
include("layers/conv.jl")
|
||||
include("layers/recurrent.jl")
|
||||
include("layers/normalisation.jl")
|
||||
include("layers/normalise.jl")
|
||||
|
||||
include("data/Data.jl")
|
||||
|
||||
include("jit/JIT.jl")
|
||||
|
||||
@require CuArrays include("cuda/cuda.jl")
|
||||
|
||||
end # module
|
||||
|
|
|
@ -4,11 +4,4 @@ using CuArrays
|
|||
|
||||
CuArrays.cudnn_available() && include("cudnn.jl")
|
||||
|
||||
import ..Flux.JIT: Shape, restructure
|
||||
|
||||
function restructure(sh::Shape{T}, buf::CuVector{UInt8}) where T
|
||||
buf = buf[1:sizeof(sh)]
|
||||
reshape(reinterpret(T, buf), size(sh))
|
||||
end
|
||||
|
||||
end
|
||||
|
|
|
@ -10,13 +10,15 @@ const cache_prefix = "https://cache.julialang.org"
|
|||
function load()
|
||||
suffixes = ["", ".phones", ".symbols"]
|
||||
if isdir(deps("cmudict"))
|
||||
if all(isfile.(["cmudict$x" for x in suffixes]))
|
||||
if all(isfile(deps("cmudict", "cmudict$x")) for x in suffixes)
|
||||
return
|
||||
end
|
||||
end
|
||||
info("Downloading CMUDict dataset")
|
||||
mkpath(deps("cmudict"))
|
||||
for x in suffixes
|
||||
download("$cache_prefix/http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x", deps("cmudict", "cmudict$x"))
|
||||
download("$cache_prefix/http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x",
|
||||
deps("cmudict", "cmudict$x"))
|
||||
end
|
||||
end
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@ function load()
|
|||
"t10k-images-idx3-ubyte",
|
||||
"t10k-labels-idx1-ubyte"]
|
||||
isfile(file) && continue
|
||||
info("Downloading MNIST dataset")
|
||||
download("https://cache.julialang.org/http://yann.lecun.com/exdb/mnist/$file.gz", "$file.gz")
|
||||
open(file, "w") do io
|
||||
write(io, GZip.open(read, "$file.gz"))
|
||||
|
|
|
@ -4,10 +4,10 @@ using ZipFile
|
|||
using ..Data: deps
|
||||
|
||||
function load()
|
||||
isfile(deps("sentiment.zip")) ||
|
||||
download("https://cache.julialang.org/https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip",
|
||||
deps("sentiment.zip"))
|
||||
return
|
||||
isfile(deps("sentiment.zip")) || return
|
||||
info("Downloading sentiment treebank dataset")
|
||||
download("https://cache.julialang.org/https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip",
|
||||
deps("sentiment.zip"))
|
||||
end
|
||||
|
||||
getfile(r, name) = r.files[findfirst(x -> x.name == name, r.files)]
|
||||
|
|
|
@ -1,9 +0,0 @@
|
|||
module JIT
|
||||
|
||||
using MacroTools
|
||||
|
||||
include("shapes.jl")
|
||||
include("trace.jl")
|
||||
include("lib.jl")
|
||||
|
||||
end
|
|
@ -1,40 +0,0 @@
|
|||
# Primitive definitions
|
||||
|
||||
shape(::typeof(*), A::MatShape{T}, B::VecShape{T}) where T =
|
||||
Shape{T}(size(A,1))
|
||||
|
||||
shape(::typeof(*), A::MatShape{T}, B::MatShape{T}) where T =
|
||||
Shape{T}(size(A,1),size(B,2))
|
||||
|
||||
inplace!(::typeof(*), C::AbstractArray, A::AbstractMatrix, B::AbstractArray) =
|
||||
A_mul_B!(C, A, B)
|
||||
|
||||
shape(::typeof(broadcast), f, xs...) =
|
||||
Shape{eltype(xs[1])}(Base.Broadcast.broadcast_shape(size.(xs)...)...)
|
||||
|
||||
inplace!(::typeof(broadcast), y, f, xs...) = broadcast!(f, y, xs...)
|
||||
|
||||
shape(::typeof(reshape), x::Shape{T}, i...) where T =
|
||||
Shape{T}(Base._reshape_uncolon(x, i))
|
||||
|
||||
inplace!(::typeof(reshape), y, x, i...) = copy!(y, x)
|
||||
|
||||
# NNlib
|
||||
|
||||
using NNlib
|
||||
using ..Tracker: _conv, _maxpool
|
||||
|
||||
shape(::typeof(softmax), x) = x
|
||||
inplace!(::typeof(softmax), y, x) = NNlib.softmax!(y, x)
|
||||
|
||||
shape(::typeof(_conv), x::Shape{T}, w::Shape{T}, stride, pad) where T =
|
||||
Shape{T}(NNlib.cdims(size(x), size(w), pad, stride))
|
||||
|
||||
inplace!(::typeof(_conv), y, x, w, stride, pad) =
|
||||
NNlib.conv!(y, x, w, stride = stride, pad = pad)
|
||||
|
||||
shape(::typeof(_maxpool), x::Shape{T}, k, pad) where T =
|
||||
Shape{T}(NNlib.pdims(size(x), k, pad, k))
|
||||
|
||||
inplace!(::typeof(_maxpool), y, x, k, pad) =
|
||||
NNlib.maxpool!(y, x, k, pad = pad)
|
|
@ -1,56 +0,0 @@
|
|||
using ..Tracker: TrackedArray
|
||||
|
||||
struct Shape{T,N}
|
||||
dims::NTuple{N,Int}
|
||||
end
|
||||
|
||||
VecShape{T} = Shape{T,1}
|
||||
MatShape{T} = Shape{T,2}
|
||||
|
||||
Shape{T}(dims::Vararg{Integer,N}) where {T,N} = Shape{T,N}(dims)
|
||||
Shape{T}(dims::NTuple{N,Integer}) where {T,N} = Shape{T,N}(dims)
|
||||
|
||||
Base.size(s::Shape) = s.dims
|
||||
Base.size(s::Shape, n) = s.dims[n]
|
||||
Base.ndims(s::Shape{T,N}) where {T,N} = N
|
||||
Base.length(s::Shape) = prod(s.dims)
|
||||
Base.eltype(s::Shape{T}) where T = T
|
||||
|
||||
Base.sizeof(s::Shape{T}) where T = sizeof(T)*prod(size(s))
|
||||
|
||||
function Base.show(io::IO, s::Shape{T}) where T
|
||||
print(io, "Shape{$T}(")
|
||||
join(io, s.dims, ", ")
|
||||
print(io, ")")
|
||||
end
|
||||
|
||||
shape(x) = x
|
||||
shape(x::Shape) = x
|
||||
shape(x::Tuple) = shape.(x)
|
||||
shape(x::AbstractArray) = Shape{eltype(x)}(size(x)...)
|
||||
shape(x::TrackedArray) = shape(x.data)
|
||||
|
||||
bytes(s::Shape) = sizeof(s)
|
||||
bytes(x::Tuple) = sum(bytes.(x))
|
||||
|
||||
# Recover structure from byte buffers
|
||||
# Make sure to hold on to the parent buffer for the lifetime of the data.
|
||||
|
||||
function restructure(sh::Shape{T}, buf::Vector{UInt8}) where T
|
||||
buf = unsafe_wrap(Array, pointer(buf), sizeof(sh))
|
||||
reshape(reinterpret(T, buf), size(sh))
|
||||
end
|
||||
|
||||
# Execution with caches
|
||||
|
||||
mutable struct Cached{F,A}
|
||||
f::F
|
||||
buffer::A
|
||||
end
|
||||
|
||||
function (c::Cached)(args...)
|
||||
sh = shape(c.f, shape(args)...)
|
||||
bytes(sh) > length(c.buffer) && (c.buffer = similar(c.buffer, bytes(sh)))
|
||||
y = restructure(sh, c.buffer)
|
||||
inplace!(c.f, y, args...)
|
||||
end
|
|
@ -1,75 +0,0 @@
|
|||
# This is hacky; we'll eventually reuse Cassette for better tracing.
|
||||
|
||||
using ..Tracker, DataFlow
|
||||
using ..Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf
|
||||
using DataFlow: Call, Lambda, iscall, isconstant, prewalk, vertex, syntax,
|
||||
inputnode, constant
|
||||
|
||||
vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...)
|
||||
vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...)
|
||||
|
||||
graph(x::Tracked, inputs...; cache = ObjectIdDict()) =
|
||||
vcall(x.f.func, map(x -> graph(x, inputs...; cache = cache), x.f.args)...)
|
||||
|
||||
function graph(x, inputs...; cache = ObjectIdDict())
|
||||
haskey(cache, x) && return cache[x]
|
||||
i = findfirst(y -> x === y, inputs)
|
||||
cache[x] =
|
||||
i > 0 ? inputnode(i) :
|
||||
istracked(x) && !isleaf(x) ? graph(tracker(x), inputs...; cache = cache) :
|
||||
constant(x)
|
||||
end
|
||||
|
||||
function trace(f, args...)
|
||||
inputs = param.(args)
|
||||
graph(f(inputs...), inputs...)
|
||||
end
|
||||
|
||||
# Graph manipulation
|
||||
|
||||
function liftparams(v)
|
||||
ps = []
|
||||
v = prewalk(DataFlow.bumpinputs(v)) do v
|
||||
isconstant(v) && istracked(v.value.value) || return v
|
||||
push!(ps, v.value.value)
|
||||
DataFlow.vcall(getindex, inputnode(1), length(ps))
|
||||
end
|
||||
return v, ps
|
||||
end
|
||||
|
||||
function cacheall(v, buf = () -> UInt8[])
|
||||
prewalk(v) do v
|
||||
iscall(v) && isconstant(v[1]) || return v
|
||||
f = v[1].value.value
|
||||
return vertex(Call(), constant(Cached(f, buf())), v[2:end]...)
|
||||
end
|
||||
end
|
||||
|
||||
code(v, n) = syntax(vertex(Lambda(n, v)))
|
||||
|
||||
struct Compiled{F,T<:Tuple}
|
||||
model
|
||||
func::F
|
||||
params::T
|
||||
end
|
||||
|
||||
# TODO when we support derivatives
|
||||
# (c::Compiled)(args...) =
|
||||
# Tracker.track(Tracker.Call(c, args...),
|
||||
# c.func(Tracker.data.(c.params), args...))
|
||||
|
||||
(c::Compiled)(args...) = c.func(Tracker.data.(c.params), Tracker.data.(args)...)
|
||||
|
||||
Base.show(io::IO, c::Compiled) = print(io, "Compiled(", c.model, ")")
|
||||
|
||||
function compile(f, args...)
|
||||
v = trace(f, args...)
|
||||
v, ps = liftparams(cacheall(v, () -> similar(args[1], UInt8, 1))) # no empty arrays on GPU
|
||||
Compiled(f, eval(code(v, length(args)+1)), (ps...,))
|
||||
end
|
||||
|
||||
function source(f, args...)
|
||||
v = trace(f, args...)
|
||||
v, ps = liftparams(v)
|
||||
code(v, length(args)+1) |> prettify
|
||||
end
|
|
@ -10,7 +10,7 @@ Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
|
|||
Data should be stored in WHCN order. In other words, a 100×100 RGB image would
|
||||
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
|
||||
|
||||
Takes the keyword arguments `pad` and `stride`.
|
||||
Takes the keyword arguments `pad`, `stride` and `dilation`.
|
||||
"""
|
||||
struct Conv{N,F,A,V}
|
||||
σ::F
|
||||
|
@ -18,17 +18,19 @@ struct Conv{N,F,A,V}
|
|||
bias::V
|
||||
stride::NTuple{N,Int}
|
||||
pad::NTuple{N,Int}
|
||||
dilation::NTuple{N,Int}
|
||||
end
|
||||
|
||||
Conv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity;
|
||||
stride = 1, pad = 0) where T =
|
||||
Conv(σ, w, b, stride, pad)
|
||||
stride = 1, pad = 0, dilation=1) where T =
|
||||
Conv(σ, w, b, stride, pad, dilation)
|
||||
|
||||
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
|
||||
stride::NTuple{N,Integer} = map(_->1,k),
|
||||
pad::NTuple{N,Integer} = map(_->0,k)) where N =
|
||||
pad::NTuple{N,Integer} = map(_->0,k),
|
||||
dilation::NTuple{N,Integer} = map(_->0,k)) where N =
|
||||
Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
|
||||
stride = stride, pad = pad)
|
||||
stride = stride, pad = pad, dilation = dilation)
|
||||
|
||||
Flux.treelike(Conv)
|
||||
|
||||
|
@ -36,7 +38,7 @@ function (c::Conv)(x)
|
|||
# TODO: breaks gpu broadcast :(
|
||||
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
|
||||
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
||||
σ.(conv(x, c.weight, stride = c.stride, pad = c.pad) .+ b)
|
||||
σ.(conv(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, l::Conv)
|
||||
|
@ -45,6 +47,3 @@ function Base.show(io::IO, l::Conv)
|
|||
l.σ == identity || print(io, ", ", l.σ)
|
||||
print(io, ")")
|
||||
end
|
||||
|
||||
# v0.5
|
||||
@deprecate Conv2D(args...; kw...) Conv(args...; kw...)
|
||||
|
|
|
@ -31,15 +31,14 @@ function Dropout(p)
|
|||
Dropout{typeof(p)}(p, true)
|
||||
end
|
||||
|
||||
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
|
||||
|
||||
function (a::Dropout)(x)
|
||||
a.active || return x
|
||||
y = similar(x)
|
||||
rand!(y)
|
||||
q = 1 - a.p
|
||||
@inbounds for i=1:length(y)
|
||||
y[i] = y[i] > a.p ? 1 / q : 0
|
||||
end
|
||||
return y .* x
|
||||
y .= _dropout_kernel.(y, a.p, 1 - a.p)
|
||||
return x .* y
|
||||
end
|
||||
|
||||
_testmode!(a::Dropout, test) = (a.active = !test)
|
||||
|
@ -68,70 +67,88 @@ function Base.show(io::IO, l::LayerNorm)
|
|||
end
|
||||
|
||||
"""
|
||||
BatchNorm(dims...; λ = identity,
|
||||
initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1)
|
||||
BatchNorm(channels::Integer, σ = identity;
|
||||
initβ = zeros, initγ = ones,
|
||||
ϵ = 1e-8, momentum = .1)
|
||||
|
||||
Batch Normalization Layer for [`Dense`](@ref) layer.
|
||||
Batch Normalization layer. The `channels` input should be the size of the
|
||||
channel dimension in your data (see below).
|
||||
|
||||
Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
|
||||
a batch of feature vectors this is just the data dimension, for `WHCN` images
|
||||
it's the usual channel dimension.)
|
||||
|
||||
`BatchNorm` computes the mean and variance for each each `W×H×1×N` slice and
|
||||
shifts them to have a new mean and variance (corresponding to the learnable,
|
||||
per-channel `bias` and `scale` parameters).
|
||||
|
||||
See [Batch Normalization: Accelerating Deep Network Training by Reducing
|
||||
Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf)
|
||||
Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf).
|
||||
|
||||
In the example of MNIST,
|
||||
in order to normalize the input of other layer,
|
||||
put the `BatchNorm` layer before activation function.
|
||||
Example:
|
||||
|
||||
```julia
|
||||
m = Chain(
|
||||
Dense(28^2, 64),
|
||||
BatchNorm(64, λ = relu),
|
||||
BatchNorm(64, relu),
|
||||
Dense(64, 10),
|
||||
BatchNorm(10),
|
||||
softmax)
|
||||
```
|
||||
"""
|
||||
mutable struct BatchNorm{F,V,N}
|
||||
mutable struct BatchNorm{F,V,W,N}
|
||||
λ::F # activation function
|
||||
β::V # bias
|
||||
γ::V # scale
|
||||
μ # moving mean
|
||||
σ # moving std
|
||||
μ::W # moving mean
|
||||
σ::W # moving std
|
||||
ϵ::N
|
||||
momentum::N
|
||||
active::Bool
|
||||
end
|
||||
|
||||
BatchNorm(dims::Integer...; λ = identity,
|
||||
BatchNorm(chs::Integer, λ = identity;
|
||||
initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1) =
|
||||
BatchNorm(λ, param(initβ(dims)), param(initγ(dims)), 0., 1., ϵ, momentum, true)
|
||||
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
|
||||
zeros(chs), ones(chs), ϵ, momentum, true)
|
||||
|
||||
function (BN::BatchNorm)(x)
|
||||
λ, γ, β = BN.λ, BN.γ, BN.β
|
||||
size(x, ndims(x)-1) == length(BN.β) ||
|
||||
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
|
||||
γ, β = BN.γ, BN.β
|
||||
dims = length(size(x))
|
||||
channels = size(x, dims-1)
|
||||
affine_shape = ones(Int, dims)
|
||||
affine_shape[end-1] = channels
|
||||
m = prod(size(x)[1:end-2]) * size(x)[end]
|
||||
|
||||
if !BN.active
|
||||
μ = BN.μ
|
||||
σ = BN.σ
|
||||
μ = reshape(BN.μ, affine_shape...)
|
||||
σ = reshape(BN.σ, affine_shape...)
|
||||
else
|
||||
T = eltype(x)
|
||||
|
||||
ϵ = data(convert(T, BN.ϵ))
|
||||
m = size(x, 2) # batch size
|
||||
μ = mean(x, 2)
|
||||
σ = sqrt.(sum((x .- μ).^2, 2) ./ m .+ ϵ)
|
||||
axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
|
||||
μ = mean(x, axes)
|
||||
σ = sqrt.(mean((x .- μ).^2, axes) .+ ϵ)
|
||||
|
||||
# update moving mean/std
|
||||
mtm = data(convert(T, BN.momentum))
|
||||
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* data(μ)
|
||||
BN.σ = (1 - mtm) .* BN.σ .+ mtm .* data(σ) .* m ./ (m - 1)
|
||||
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* squeeze(data(μ), (axes...))
|
||||
BN.σ = (1 - mtm) .* BN.σ .+ mtm .* squeeze(data(σ), (axes...)) .* m ./ (m - 1)
|
||||
end
|
||||
|
||||
λ.(γ .* ((x .- μ) ./ σ) .+ β)
|
||||
let λ = BN.λ
|
||||
λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...))
|
||||
end
|
||||
end
|
||||
|
||||
children(BN::BatchNorm) =
|
||||
(BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.momentum, BN.ϵ, BN.active)
|
||||
(BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.ϵ, BN.momentum, BN.active)
|
||||
|
||||
mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN)
|
||||
BatchNorm(BN.λ, f(BN.β), f(BN.γ), BN.μ, BN.σ, BN.momentum, BN.ϵ, BN.active)
|
||||
BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ), BN.ϵ, BN.momentum, BN.active)
|
||||
|
||||
_testmode!(BN::BatchNorm, test) = (BN.active = !test)
|
||||
|
|
@ -5,7 +5,7 @@ using NNlib: logsoftmax, logσ
|
|||
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
|
||||
|
||||
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
|
||||
return @fix -sum(y .* log.(ŷ) .* weight) / size(y, 2)
|
||||
@fix -sum(y .* log.(ŷ) .* weight) / size(y, 2)
|
||||
end
|
||||
|
||||
@deprecate logloss(x, y) crossentropy(x, y)
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
module Optimise
|
||||
|
||||
export update!, params, train!,
|
||||
SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
|
||||
export train!,
|
||||
SGD, ADAM, AdaMax, Momentum, Nesterov,
|
||||
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
|
||||
|
||||
struct Param{T}
|
||||
x::T
|
||||
|
|
|
@ -56,6 +56,15 @@ RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) =
|
|||
ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
|
||||
optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
|
||||
|
||||
"""
|
||||
AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
|
||||
|
||||
[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on
|
||||
the ∞-norm.
|
||||
"""
|
||||
AdaMax(ps, η = 0.002; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
|
||||
optimiser(ps, p->adamax(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
|
||||
|
||||
"""
|
||||
ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0)
|
||||
|
||||
|
|
|
@ -62,6 +62,18 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ
|
|||
end
|
||||
end
|
||||
|
||||
function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
||||
mt = zeros(p.x)
|
||||
ut = zeros(p.x)
|
||||
β1p = β1
|
||||
function ()
|
||||
@. mt = β1 * mt + (1 - β1) * p.Δ
|
||||
@. ut = max(β2 * ut, abs(p.Δ))
|
||||
@. p.Δ = (η/(1 - β1p)) * mt/(ut + ϵ)
|
||||
β1p *= β1
|
||||
end
|
||||
end
|
||||
|
||||
function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
||||
mt = zeros(p.x)
|
||||
vt = zeros(p.x) .+ ϵ
|
||||
|
|
|
@ -41,7 +41,7 @@ end
|
|||
Base.setindex!(xs::TrackedArray, v, i...) =
|
||||
error("Can't differentiate `setindex!`")
|
||||
|
||||
back!(::TrackedArray) = error("Use back!(x, Δ)")
|
||||
back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`")
|
||||
|
||||
# Fallthrough methods
|
||||
|
||||
|
@ -81,21 +81,6 @@ back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ'))
|
|||
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
|
||||
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
|
||||
|
||||
Base.vcat(a::TrackedVector, b::TrackedVector) = track(vcat, a, b)
|
||||
Base.vcat(a::TrackedVector, b::TrackedVector...) = track(vcat, a, b...)
|
||||
Base.vcat(a::TrackedVector, b::AbstractVector) = track(vcat, a, b)
|
||||
Base.vcat(a::AbstractVector, b::TrackedVector) = track(vcat, a, b)
|
||||
|
||||
Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b)
|
||||
Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat...) = track(vcat, a, b...)
|
||||
Base.vcat(a::TrackedVecOrMat, b::AbstractVecOrMat) = track(vcat, a, b)
|
||||
Base.vcat(a::AbstractVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b)
|
||||
|
||||
Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = track(vcat, a, b)
|
||||
Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = track(vcat, a, b...)
|
||||
Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = track(vcat, a, b)
|
||||
Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = track(vcat, a, b)
|
||||
|
||||
function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
|
||||
Δ′ = similar(xs.data)
|
||||
S = size(xs.data)
|
||||
|
@ -108,20 +93,93 @@ function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
|
|||
back(xs, Δ′)
|
||||
end
|
||||
|
||||
|
||||
_repeat(A, inner, outer) = Base.repeat(A; inner=inner, outer=outer)
|
||||
Base.repeat(A::TrackedArray; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A))) = track(_repeat, A, inner, outer)
|
||||
|
||||
function back(::typeof(_repeat), Δ, xs::TrackedArray, inner, outer)
|
||||
Δ′ = similar(xs.data)
|
||||
Δ′ .= 0
|
||||
S = size(xs.data)
|
||||
|
||||
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
|
||||
for (dest_idx, val) in enumerate(IndexCartesian(), Δ)
|
||||
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
|
||||
# wrap around based on original size S.
|
||||
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
|
||||
Δ′[src_idx...] += val
|
||||
end
|
||||
back(xs, Δ′)
|
||||
end
|
||||
|
||||
|
||||
for f in [:vcat, :hcat]
|
||||
@eval begin
|
||||
# This section is a bit of a hack since julia doesn't have a standardised
|
||||
# promotion mechanism for concatenation yet
|
||||
# https://github.com/JuliaLang/julia/pull/20815
|
||||
|
||||
# It should support tracked concatenation with rank ∈ (1,2) with a
|
||||
# TrackedArray anywhere among the arguments This works as long as base has
|
||||
# other functions that captures `(::Union{Vector,RowVector,Matrix}...)`.
|
||||
Base.$f(a::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a...)
|
||||
|
||||
# It should support tracked concatenation with rank>2 if the TrackedArray is
|
||||
# first
|
||||
Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...)
|
||||
Base.$f(a::TrackedArray, b::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a, b...) # resolves ambiguity introduced by previous row
|
||||
|
||||
# It should support tracked concatenation with rank>2 if the TrackedArray is
|
||||
# second
|
||||
Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...)
|
||||
Base.$f(a::Union{Vector,RowVector,Matrix}, b::TrackedArray,
|
||||
c::Union{TrackedArray,Vector,RowVector,Matrix}...) =
|
||||
track($f, a, b, c...) # resolves ambiguity introduced by previous row
|
||||
end
|
||||
end
|
||||
|
||||
function back(::typeof(vcat), Δ, xs...)
|
||||
i = Base.tail(map(_ -> :, size(Δ)))
|
||||
start = 0
|
||||
for xsi in xs
|
||||
i = map(_ -> :, size(xsi)) |> Base.tail
|
||||
@back(xsi, Δ[start+1:start+size(xsi,1), i...])
|
||||
start += size(xsi, 1)
|
||||
end
|
||||
end
|
||||
|
||||
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) =
|
||||
track(reshape, xs, dims...)
|
||||
function back(::typeof(hcat), Δ, xs...)
|
||||
start = 0
|
||||
for xsi in xs
|
||||
if ndims(xsi) == 1
|
||||
@back(xsi, Δ[:, start+1])
|
||||
else
|
||||
i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
|
||||
@back(xsi, Δ[:, start+1:start+size(xsi,2), i...])
|
||||
end
|
||||
start += size(xsi, 2)
|
||||
end
|
||||
end
|
||||
|
||||
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64,N}} where N) =
|
||||
track(reshape, xs, dims)
|
||||
Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...)
|
||||
Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...)
|
||||
|
||||
function back(::typeof(cat), Δ, dims, Xs...)
|
||||
start = ntuple(i -> 0, Val{ndims(Δ)})
|
||||
for xs in Xs
|
||||
dim_xs = 1:ndims(xs)
|
||||
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)})
|
||||
|
||||
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)})
|
||||
|
||||
@back(xs, reshape(Δ[xs_in_Δ...],size(xs)))
|
||||
|
||||
start = start .+ till_xs
|
||||
end
|
||||
end
|
||||
|
||||
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims)
|
||||
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims))
|
||||
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims)
|
||||
|
||||
back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
|
||||
back(xs, reshape(Δ, size(xs)))
|
||||
|
@ -158,12 +216,16 @@ Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
|
|||
back(::typeof(prod), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= (prod(xs.data, dim...) ./ xs.data) .* Δ)
|
||||
back(::typeof(prod), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= (reshape(.*(circshift.([reshape(xs.data, length(xs.data))], 1:length(xs.data)-1)...), size(xs.data))) .* Δ)
|
||||
|
||||
Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...)
|
||||
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
|
||||
|
||||
Base.mean(xs::TrackedArray) = track(mean, xs)
|
||||
Base.mean(xs::TrackedArray, region) = track(mean, xs, region)
|
||||
|
||||
Base.maximum(xs::TrackedArray) = track(maximum, xs)
|
||||
Base.maximum(xs::TrackedArray, region) = track(maximum, xs, region)
|
||||
Base.minimum(xs::TrackedArray) = track(minimum, xs)
|
||||
Base.minimum(xs::TrackedArray, region) = track(minimum, xs, region)
|
||||
|
||||
LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
|
||||
LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
|
||||
LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
|
||||
|
@ -186,6 +248,31 @@ back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./
|
|||
back(::typeof(mean), Δ, xs::TrackedArray, region) =
|
||||
back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...)))
|
||||
|
||||
function back(::typeof(maximum), Δ, xs::TrackedArray)
|
||||
Δ′ = zeros(xs.data)
|
||||
_, i = findmax(xs.data)
|
||||
Δ′[i] = Δ
|
||||
@back(xs, Δ′)
|
||||
end
|
||||
function back(::typeof(maximum), Δ, xs::TrackedArray, region)
|
||||
Δ′ = zeros(xs.data)
|
||||
_, is = findmax(xs.data, region)
|
||||
Δ′[is] = Δ
|
||||
@back(xs, Δ′)
|
||||
end
|
||||
function back(::typeof(minimum), Δ, xs::TrackedArray)
|
||||
Δ′ = zeros(xs.data)
|
||||
_, i = findmin(xs.data)
|
||||
Δ′[i] = Δ
|
||||
@back(xs, Δ′)
|
||||
end
|
||||
function back(::typeof(minimum), Δ, xs::TrackedArray, region)
|
||||
Δ′ = zeros(xs.data)
|
||||
_, is = findmin(xs.data, region)
|
||||
Δ′[is] = Δ
|
||||
@back(xs, Δ′)
|
||||
end
|
||||
|
||||
# BLAS
|
||||
|
||||
Base.diagm(x::TrackedVector) = track(diagm, x)
|
||||
|
@ -247,35 +334,35 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
|
|||
back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs)))
|
||||
|
||||
# TODO: can store kwargs efficiently in namedtuples
|
||||
_conv(x, w, stride, pad) = conv(x, w, stride = stride, pad = pad)
|
||||
_conv(x, w, stride, pad, dilation) = conv(x, w, stride = stride, pad = pad, dilation = dilation)
|
||||
|
||||
conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N =
|
||||
track(_conv, x, w, stride, pad)
|
||||
conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N =
|
||||
track(_conv, x, w, stride, pad)
|
||||
conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0) where N =
|
||||
track(_conv, x, w, stride, pad)
|
||||
conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
|
||||
track(_conv, x, w, stride, pad, dilation)
|
||||
conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
|
||||
track(_conv, x, w, stride, pad, dilation)
|
||||
conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
|
||||
track(_conv, x, w, stride, pad, dilation)
|
||||
|
||||
function back(::typeof(_conv), Δ, x, w, stride, pad)
|
||||
@back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad))
|
||||
@back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad))
|
||||
function back(::typeof(_conv), Δ, x, w, stride, pad, dilation)
|
||||
@back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation))
|
||||
@back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation))
|
||||
end
|
||||
|
||||
_maxpool(x, k, pad) = maxpool(x, k; pad = pad)
|
||||
_maxpool(x, k, pad, stride) = maxpool(x, k; pad = pad, stride = stride)
|
||||
|
||||
maxpool(x::TrackedArray, k; pad = map(_->0,k)) =
|
||||
track(_maxpool, x, k, pad)
|
||||
maxpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =
|
||||
track(_maxpool, x, k, pad, stride)
|
||||
|
||||
back_(::typeof(_maxpool), y, Δ, x, k, pad) =
|
||||
back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad))
|
||||
back_(::typeof(_maxpool), y, Δ, x, k, pad, stride) =
|
||||
back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad, stride=stride))
|
||||
|
||||
_meanpool(x, k, pad) = meanpool(x, k; pad = pad)
|
||||
_meanpool(x, k, pad, stride) = meanpool(x, k; pad = pad, stride = stride)
|
||||
|
||||
meanpool(x::TrackedArray, k; pad = map(_->0,k)) =
|
||||
track(_meanpool, x, k, pad)
|
||||
meanpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =
|
||||
track(_meanpool, x, k, pad, stride)
|
||||
|
||||
back_(::typeof(_meanpool), y, Δ, x, k, pad) =
|
||||
back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad))
|
||||
back_(::typeof(_meanpool), y, Δ, x, k, pad, stride) =
|
||||
back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad, stride=stride))
|
||||
|
||||
# Broadcasting
|
||||
|
||||
|
|
|
@ -19,8 +19,9 @@ Base.decompose(x::TrackedReal) = Base.decompose(data(x))
|
|||
|
||||
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x
|
||||
|
||||
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T =
|
||||
TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data)))
|
||||
# This cuts derivatives, fix if needed.
|
||||
# Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T =
|
||||
# TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data)))
|
||||
|
||||
Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))
|
||||
|
||||
|
@ -91,3 +92,18 @@ Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
|
|||
|
||||
back(::typeof(getindex), Δ, t, i) =
|
||||
back(t, ntuple(j -> i == j ? Δ : 0, length(t)))
|
||||
|
||||
# Array collection
|
||||
|
||||
function collect(xs)
|
||||
xs = Base.collect(xs)
|
||||
track(Call(collect, xs), data.(xs))
|
||||
end
|
||||
|
||||
function scan(c::Call{typeof(collect)})
|
||||
foreach(scan, c.args[1])
|
||||
end
|
||||
|
||||
function back(::typeof(collect), Δ, xs)
|
||||
foreach((x, Δ) -> @back(x, Δ), xs, Δ)
|
||||
end
|
||||
|
|
|
@ -21,6 +21,10 @@ cm = gpu(m)
|
|||
@test all(p isa TrackedArray && p.data isa CuArray for p in params(cm))
|
||||
@test cm(gpu(rand(10, 10))) isa TrackedArray{Float32,2,CuArray{Float32,2}}
|
||||
|
||||
x = [1,2,3]
|
||||
cx = gpu(x)
|
||||
@test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)
|
||||
|
||||
# Fails in Pkg.test ffs
|
||||
# c = gpu(Conv((2,2),3=>4))
|
||||
# l = c(gpu(rand(10,10,3,2)))
|
||||
|
|
12
test/jit.jl
12
test/jit.jl
|
@ -1,12 +0,0 @@
|
|||
using Flux, Base.Test
|
||||
using Flux.JIT: compile
|
||||
|
||||
@testset "JIT" begin
|
||||
|
||||
m = Dense(10, 5)
|
||||
f = compile(m, rand(10))
|
||||
x = rand(10)
|
||||
|
||||
@test Tracker.data(m(x)) == f(x)
|
||||
|
||||
end
|
|
@ -67,7 +67,7 @@ end
|
|||
end
|
||||
|
||||
# with activation function
|
||||
let m = BatchNorm(2, λ = σ), x = param([1 2; 3 4; 5 6]')
|
||||
let m = BatchNorm(2, σ), x = param([1 2; 3 4; 5 6]')
|
||||
@test m.active
|
||||
m(x)
|
||||
|
||||
|
@ -77,4 +77,22 @@ end
|
|||
x′ = m(x).data
|
||||
@test x′[1] ≈ σ((1 - 0.3) / 1.1449489742783179)
|
||||
end
|
||||
|
||||
let m = BatchNorm(2), x = param(reshape(1:6, 3, 2, 1))
|
||||
y = reshape(permutedims(x, [2, 1, 3]), 2, :)
|
||||
y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
|
||||
@test m(x) == y
|
||||
end
|
||||
|
||||
let m = BatchNorm(2), x = param(reshape(1:12, 2, 3, 2, 1))
|
||||
y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
|
||||
y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
|
||||
@test m(x) == y
|
||||
end
|
||||
|
||||
let m = BatchNorm(2), x = param(reshape(1:24, 2, 2, 3, 2, 1))
|
||||
y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
|
||||
y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
|
||||
@test m(x) == y
|
||||
end
|
||||
end
|
||||
|
|
|
@ -3,7 +3,7 @@ using Flux.Tracker
|
|||
|
||||
@testset "Optimise" begin
|
||||
w = randn(10, 10)
|
||||
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
|
||||
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
|
||||
w′ = param(randn(10, 10))
|
||||
loss(x) = Flux.mse(w*x, w′*x)
|
||||
opt = Opt([w′])
|
||||
|
|
|
@ -10,7 +10,6 @@ include("layers/normalisation.jl")
|
|||
include("layers/stateless.jl")
|
||||
include("optimise.jl")
|
||||
include("data.jl")
|
||||
include("jit.jl")
|
||||
|
||||
if Base.find_in_path("CuArrays") ≠ nothing
|
||||
include("cuda/cuda.jl")
|
||||
|
|
136
test/tracker.jl
136
test/tracker.jl
|
@ -1,5 +1,5 @@
|
|||
using Flux.Tracker, Base.Test, NNlib
|
||||
using Flux.Tracker: TrackedReal, gradcheck
|
||||
using Flux.Tracker: TrackedReal, gradcheck, grad
|
||||
using NNlib: conv
|
||||
|
||||
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
|
||||
|
@ -29,17 +29,97 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||
|
||||
@test gradtest(x -> x', rand(5))
|
||||
|
||||
@test gradtest(vcat, rand(5), rand(3))
|
||||
@test gradtest(vcat, rand(5), rand(3), rand(8))
|
||||
@test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2))
|
||||
function promotiontest(f, A, B, C)
|
||||
r0 = f(A, B, C)
|
||||
r1 = f(param(A), B, C)
|
||||
r2 = f(A, param(B), C)
|
||||
if all(ndims.((A,B,C)) .≤ 2) && f ∈ [hcat, vcat]
|
||||
r3 = f(A, B, param(C))
|
||||
else
|
||||
@test_throws MethodError f(A, B, param(C)) # until julia#20815 is resolved
|
||||
r3 = r2
|
||||
end
|
||||
r4 = f(param(A), param(B), param(C))
|
||||
|
||||
@test !isa(r0, TrackedArray)
|
||||
@test all(isa.([r1,r2,r3,r4], TrackedArray))
|
||||
@test r1 == r2 == r3 == r4
|
||||
@test r0 == Flux.data(r4)
|
||||
end
|
||||
|
||||
@testset "concat" begin
|
||||
cat1(x...) = cat(1, x...)
|
||||
cat2(x...) = cat(2, x...)
|
||||
|
||||
@testset for vcatf in [vcat, cat1]
|
||||
@test gradtest(vcatf, rand(5), rand(3))
|
||||
@test gradtest(vcatf, rand(5), rand(3), rand(8))
|
||||
@test gradtest(vcatf, rand(5)', rand(5)')
|
||||
@test gradtest(vcatf, rand(5,2), rand(3,2), rand(8,2))
|
||||
@test gradtest(vcatf, rand(5,2,3), rand(3,2,3), rand(8,2,3))
|
||||
@test gradtest(vcatf, rand(5), rand(3,1))
|
||||
@test gradtest(vcatf, rand(5)', rand(2,5))
|
||||
end
|
||||
|
||||
@testset for hcatf in [hcat, cat2]
|
||||
@test gradtest(hcatf, rand(5), rand(5))
|
||||
@test gradtest(hcatf, rand(5)', rand(5)')
|
||||
@test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8))
|
||||
@test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3))
|
||||
@test gradtest(hcatf, rand(5), rand(5), rand(5,2))
|
||||
@test gradtest(hcatf, rand(5)', rand(1,3))
|
||||
@test gradtest(hcatf, rand(5), rand(5,2))
|
||||
end
|
||||
|
||||
@testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
|
||||
@test gradtest(catf, rand(5))
|
||||
@test gradtest(catf, rand(5)')
|
||||
@test gradtest(catf, rand(2,5))
|
||||
@test gradtest(catf, rand(2,5,3))
|
||||
end
|
||||
|
||||
@test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4))
|
||||
|
||||
@testset "cat($dim, ...)" for dim in 3:5
|
||||
catdim = (x...) -> cat(dim, x...)
|
||||
@test gradtest(catdim, rand(5), rand(5), rand(5))
|
||||
@test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5))
|
||||
@test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3))
|
||||
end
|
||||
|
||||
@test !isa(vcat(rand(2)), TrackedArray)
|
||||
@test !isa(hcat(rand(2)), TrackedArray)
|
||||
@test !isa(cat(1,rand(2)), TrackedArray)
|
||||
|
||||
@test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1))
|
||||
|
||||
@testset "promotiontest" begin
|
||||
@testset for fcat in [hcat, vcat, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
|
||||
promotiontest(fcat, rand(2), rand(2), rand(2))
|
||||
promotiontest(fcat, rand(2)', rand(2)', rand(2)')
|
||||
promotiontest(fcat, rand(2,2), rand(2,2), rand(2,2))
|
||||
promotiontest(fcat, rand(2,2,2), rand(2,2,2), rand(2,2,2))
|
||||
end
|
||||
|
||||
promotiontest(vcat, rand(1,2), rand(2)', rand(2,2))
|
||||
promotiontest(hcat, rand(2,1), rand(2), rand(2,2))
|
||||
promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5))
|
||||
promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5))
|
||||
promotiontest((x...) -> cat(3, x...), rand(4,5,3), rand(4,5,1), rand(4,5,2))
|
||||
end
|
||||
end
|
||||
|
||||
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
|
||||
|
||||
@test gradtest(x -> repmat(x, 5,5), rand(4,5))
|
||||
@test gradtest(x -> repmat(x, 5), rand(4,5))
|
||||
|
||||
@test gradtest(kron,rand(5), rand(3))
|
||||
@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
|
||||
@test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))
|
||||
|
||||
@test gradtest(kron, rand(5), rand(3))
|
||||
@test gradtest(kron, rand(5), rand(3), rand(8))
|
||||
@test gradtest(kron,rand(5,1), rand(3,1))
|
||||
@test gradtest(kron, rand(5,1), rand(3,1))
|
||||
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
|
||||
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))
|
||||
|
||||
|
@ -55,6 +135,26 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||
@test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4))
|
||||
end
|
||||
|
||||
@testset "maximum" begin
|
||||
@test gradtest(maximum, rand(2, 3))
|
||||
|
||||
@test gradtest(x -> maximum(x, 1), rand(2, 3))
|
||||
@test gradtest(x -> maximum(x, 2), rand(2, 3))
|
||||
@test gradtest(x -> maximum(x, 3), rand(2, 3, 4))
|
||||
|
||||
@test gradtest(x -> maximum(x, [1, 2]), rand(2, 3, 4))
|
||||
end
|
||||
|
||||
@testset "minimum" begin
|
||||
@test gradtest(minimum, rand(2, 3))
|
||||
|
||||
@test gradtest(x -> minimum(x, 1), rand(2, 3))
|
||||
@test gradtest(x -> minimum(x, 2), rand(2, 3))
|
||||
@test gradtest(x -> minimum(x, 3), rand(2, 3, 4))
|
||||
|
||||
@test gradtest(x -> minimum(x, [1, 2]), rand(2, 3, 4))
|
||||
end
|
||||
|
||||
@test gradtest(x -> std(x), rand(5,5))
|
||||
@test gradtest(x -> std(x, 1), rand(5,5))
|
||||
|
||||
|
@ -82,6 +182,21 @@ end
|
|||
|
||||
@test param(2)^2 == 4.0
|
||||
|
||||
@testset "reshape" begin
|
||||
x = reshape(param(rand(2,2,2)), 4, 2)
|
||||
@test x isa TrackedArray
|
||||
@test size(x) == (4,2)
|
||||
x = reshape(param([1]), (1,:))
|
||||
@test x isa TrackedArray
|
||||
@test size(x) == (1,1)
|
||||
x = reshape(param(rand(2)), (2,:))
|
||||
@test x isa TrackedArray
|
||||
@test size(x) == (2,1)
|
||||
x = reshape(param(rand(2,2)), (1,:,2))
|
||||
@test x isa TrackedArray
|
||||
@test size(x) == (1,2,2)
|
||||
end
|
||||
|
||||
@testset "Intermediates" begin
|
||||
x = param([1])
|
||||
l = sum((x .+ x).^2)
|
||||
|
@ -108,4 +223,13 @@ b = param(rand())
|
|||
Tracker.back!(b)
|
||||
@test Tracker.grad(b) == 1
|
||||
|
||||
@testset "collect" begin
|
||||
x, y = param(2), param(3)
|
||||
xy = Tracker.collect([x, y])
|
||||
@test xy isa TrackedArray{Float64}
|
||||
z = xy[1]*xy[2]
|
||||
back!(z)
|
||||
@test grad.((x,y)) == (3, 2)
|
||||
end
|
||||
|
||||
end #testset
|
||||
|
|
Loading…
Reference in New Issue