Fix merge conflicts
commit 2559e7b4e6
@@ -15,5 +15,5 @@ matrix:
   allow_failures:
     - julia: nightly
 after_success:
-  - julia -e 'using Pkg; Pkg.add("Documenter"); Pkg.add("NNlib")'
+  - julia -e 'using Pkg; ps=Pkg.PackageSpec(name="Documenter", version="0.19"); Pkg.add(ps); Pkg.pin(ps); Pkg.add("NNlib")'
   - julia -e 'using Pkg; cd(Pkg.dir("Flux")); include(joinpath("docs", "make.jl"))'
@@ -3,7 +3,7 @@
 Flux is a library for machine learning. It comes "batteries-included" with many useful tools built in, but also lets you use the full power of the Julia language where you need it. We follow a few key principles:
 
 * **Doing the obvious thing**. Flux has relatively few explicit APIs for features like regularisation or embeddings. Instead, writing down the mathematical form will work – and be fast.
-* **You could have written Flux**. All of it, from [LSTMs](https://github.com/FluxML/Flux.jl/blob/ec16a2c77dbf6ab8b92b0eecd11661be7a62feef/src/layers/recurrent.jl#L131) to [GPU kernels](https://github.com/JuliaGPU/CuArrays.jl), is straightforward Julia code. When it doubt, it’s well worth looking at [the source](https://github.com/FluxML/Flux.jl/). If you need something different, you can easily roll your own.
+* **You could have written Flux**. All of it, from [LSTMs](https://github.com/FluxML/Flux.jl/blob/ec16a2c77dbf6ab8b92b0eecd11661be7a62feef/src/layers/recurrent.jl#L131) to [GPU kernels](https://github.com/JuliaGPU/CuArrays.jl), is straightforward Julia code. When in doubt, it’s well worth looking at [the source](https://github.com/FluxML/Flux.jl/). If you need something different, you can easily roll your own.
 * **Play nicely with others**. Flux works well with Julia libraries from [data frames](https://github.com/JuliaComputing/JuliaDB.jl) and [images](https://github.com/JuliaImages/Images.jl) to [differential equation solvers](https://github.com/JuliaDiffEq/DifferentialEquations.jl), so you can easily build complex data processing pipelines that integrate Flux models.
 
 ## Installation
@@ -100,16 +100,16 @@ minus(a, b) = a - b
 Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch:
 
 ```julia
-using Flux.Tracker: TrackedReal, track, @grad
+using Flux.Tracker: TrackedArray, track, @grad
 
-minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b)
+minus(a::TrackedArray, b::TrackedArray) = track(minus, a, b)
 ```
 
 `track` takes care of building a new `Tracked` object and recording the operation on the tape. We just need to provide a gradient definition.
 
 ```julia
 @grad function minus(a, b)
-  return minus(data(a),data(b)), Δ -> (Δ, -Δ)
+  return minus(data(a), data(b)), Δ -> (Δ, -Δ)
 end
 ```
 
@@ -121,6 +121,19 @@ Note that in the backpropagator we don't call `data(a)`; we *do* in fact want to
 @grad a * b = data(a)*data(b), Δ -> (Δ*b, a*Δ)
 ```
 
+We can then calculate the first derivative of `minus` as follows:
+
+```julia
+a = param([1,2,3])
+b = param([3,2,1])
+
+c = minus(a, b)  # [-2.0 (tracked), 0.0 (tracked), 2.0 (tracked)]
+
+Tracker.back!(c, 1)
+Tracker.grad(a)  # [1.00, 1.00, 1.00]
+Tracker.grad(b)  # [-1.00, -1.00, -1.00]
+```
+
 For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, TrackedArray)` and so on. To do so, just define those extra signatures as needed:
 
 ```julia
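For reference, the extra signatures that last paragraph asks for would follow the same pattern as the tracked/tracked method above; a minimal sketch (the hunk is truncated here, so these exact lines are illustrative):

```julia
# mixed tracked/untracked arguments still get recorded on the tape
minus(a::TrackedArray, b::AbstractArray) = track(minus, a, b)
minus(a::AbstractArray, b::TrackedArray) = track(minus, a, b)
```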
@@ -10,14 +10,14 @@ using Flux.Tracker
 f(x) = 3x^2 + 2x + 1
 
 # df/dx = 6x + 2
-f′(x) = Tracker.gradient(f, x)[1]
+df(x) = Tracker.gradient(f, x)[1]
 
-f′(2) # 14.0 (tracked)
+df(2) # 14.0 (tracked)
 
 # d²f/dx² = 6
-f′′(x) = Tracker.gradient(f′, x)[1]
+d2f(x) = Tracker.gradient(df, x)[1]
 
-f′′(2) # 6.0 (tracked)
+d2f(2) # 6.0 (tracked)
 ```
 
 (We'll learn more about why these numbers show up as `(tracked)` below.)
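A quick sanity check of the renamed helpers above, written as a runnable sketch (values are the ones the docs already quote):

```julia
using Flux.Tracker

f(x)   = 3x^2 + 2x + 1
df(x)  = Tracker.gradient(f, x)[1]
d2f(x) = Tracker.gradient(df, x)[1]

Tracker.data(df(2))   # 14.0, i.e. 6x + 2 at x = 2
Tracker.data(d2f(2))  # 6.0
```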
@@ -2,7 +2,7 @@ module CUDA
 
 using ..CuArrays
 
-if CuArrays.cudnn_available()
+if CuArrays.libcudnn != nothing
   include("curnn.jl")
   include("cudnn.jl")
 end
@@ -24,10 +24,10 @@ const RNN_ALGO_PERSIST_DYNAMIC = 2
 # LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]
 
 function params(w::CuVector, input, hidden, n = 1)
-  slice(offset, shape) = reshape(w[offset.+(1:prod(shape))], shape)
+  slice(offset, shape) = reshape(view(w, offset.+(1:prod(shape))), shape)
   wx = slice(0, (input, hidden*n))
   wh = slice(length(wx), (hidden, hidden*n))
-  bias = w[length(wx)+length(wh) .+ (1:hidden*n)]
+  bias = view(w, length(wx)+length(wh) .+ (1:hidden*n))
   (wx, wh), bias
 end
 
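The switch from indexing to `view` means the returned weight matrices alias the flat parameter vector instead of copying it. A minimal CPU sketch of that aliasing behaviour (a plain `Vector` standing in for the `CuVector`):

```julia
w = zeros(12)
slice(offset, shape) = reshape(view(w, offset .+ (1:prod(shape))), shape)

wx = slice(0, (2, 3))   # a 2×3 window into w, no copy made
wx[1, 1] = 42.0
w[1] == 42.0            # true: writes go straight through to the underlying vector
```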
@@ -306,7 +306,7 @@ end
     h_ = hBatch(x, data(h))
     dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
     (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
-    nobacksies(:RNN, (dx, unbroadcast(size(h), dh), transpose(dWi), transpose(dWh), db))
+    nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db))
   end
 end
 
@@ -320,7 +320,7 @@ end
     dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
     (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
     nobacksies(:RNN,
-      (dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc),
+      (dx, unbroadcast(h, dh), unbroadcast(c, dc),
        transpose(dWi), transpose(dWh), db))
   end
 end
@@ -13,6 +13,9 @@ end
 include("mnist.jl")
 export MNIST
 
+include("fashion-mnist.jl")
+export FashionMNIST
+
 include("cmudict.jl")
 using .CMUDict
 
src/data/fashion-mnist.jl (new file)
@@ -0,0 +1,64 @@
+module FashionMNIST
+
+using ..MNIST: gzopen, imageheader, rawimage, labelheader, rawlabel
+
+const dir = joinpath(@__DIR__, "../../deps/fashion-mnist")
+
+function load()
+  mkpath(dir)
+  cd(dir) do
+    for file in ["train-images-idx3-ubyte",
+                 "train-labels-idx1-ubyte",
+                 "t10k-images-idx3-ubyte",
+                 "t10k-labels-idx1-ubyte"]
+      isfile(file) && continue
+      @info "Downloading Fashion-MNIST dataset"
+      download("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/$file.gz", "$file.gz")
+      open(file, "w") do io
+        write(io, gzopen(read, "$file.gz"))
+      end
+    end
+  end
+end
+
+const TRAINIMAGES = joinpath(dir, "train-images-idx3-ubyte")
+const TRAINLABELS = joinpath(dir, "train-labels-idx1-ubyte")
+const TESTIMAGES = joinpath(dir, "t10k-images-idx3-ubyte")
+const TESTLABELS = joinpath(dir, "t10k-labels-idx1-ubyte")
+
+"""
+    images()
+    images(:test)
+
+Load the Fashion-MNIST images.
+
+Each image is a 28×28 array of `Gray` colour values (see Colors.jl).
+
+Returns the 60,000 training images by default; pass `:test` to retreive the
+10,000 test images.
+"""
+function images(set = :train)
+  load()
+  io = IOBuffer(read(set == :train ? TRAINIMAGES : TESTIMAGES))
+  _, N, nrows, ncols = imageheader(io)
+  [rawimage(io) for _ in 1:N]
+end
+
+"""
+    labels()
+    labels(:test)
+
+Load the labels corresponding to each of the images returned from `images()`.
+Each label is a number from 0-9.
+
+Returns the 60,000 training labels by default; pass `:test` to retreive the
+10,000 test labels.
+"""
+function labels(set = :train)
+  load()
+  io = IOBuffer(read(set == :train ? TRAINLABELS : TESTLABELS))
+  _, N = labelheader(io)
+  [rawlabel(io) for _ = 1:N]
+end
+
+end
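A usage sketch for the new dataset module (the files are downloaded on first use; the module path here assumes `FashionMNIST` is reached the same way `MNIST` is, via `Flux.Data`):

```julia
using Flux
using Flux.Data.FashionMNIST

imgs = FashionMNIST.images()   # 60,000 28×28 images of Gray values
lbls = FashionMNIST.labels()   # 60,000 labels in 0:9
length(imgs) == length(lbls)   # true
```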
@@ -4,7 +4,7 @@ using ZipFile
 using ..Data: deps
 
 function load()
-  isfile(deps("sentiment.zip")) || return
+  isfile(deps("sentiment.zip")) && return
   @info "Downloading sentiment treebank dataset"
   download("https://cache.julialang.org/https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip",
            deps("sentiment.zip"))
@@ -26,9 +26,10 @@ totree_(n, a, b) = Tree{Any}((parse(Int, n), nothing), totree(a), totree(b))
 totree(t::Expr) = totree_(t.args...)
 
 function parsetree(s)
-  s = replace(s, r"\$", s -> "\\\$")
-  s = replace(s, r"[^\s\(\)]+", s -> "\"$s\"")
-  s = replace(s, " ", ", ")
+  s = replace(s, "\\" => "")
+  s = replace(s, "\$" => "\\\$")
+  s = replace(s, r"[^ \n\(\)]+" => s -> "\"$s\"")
+  s = replace(s, " " => ", ")
   return totree(Meta.parse(s))
 end
 
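The rewritten `parsetree` uses the Julia 0.7/1.0 `pattern => replacement` form of `replace`. A small sketch of what the pipeline does to one treebank-style line (hypothetical input, not from the dataset):

```julia
s = "(3 (2 good) (2 movie))"
s = replace(s, "\\" => "")                       # drop stray backslashes
s = replace(s, "\$" => "\\\$")                   # escape dollar signs
s = replace(s, r"[^ \n\(\)]+" => s -> "\"$s\"")  # quote every token
s = replace(s, " " => ", ")                      # commas so it parses as tuples
# s == "(\"3\", (\"2\", \"good\"), (\"2\", \"movie\"))"
Meta.parse(s)  # a nested tuple expression, ready for `totree`
```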
@@ -75,7 +75,7 @@ end
 
 @treelike Dense
 
-function (a::Dense)(x)
+function (a::Dense)(x::AbstractArray)
   W, b, σ = a.W, a.b, a.σ
   σ.(W*x .+ b)
 end
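Restricting the `Dense` call to `AbstractArray` makes accidental scalar calls and broadcasting over the layer fail loudly, which is exactly what the new `test/layers/basic.jl` below checks. A short sketch:

```julia
d = Dense(10, 5)
d(rand(10))     # works: a length-5 (tracked) output
# d(1)          # MethodError: scalars are no longer accepted
# d.(rand(10))  # MethodError: broadcasting the layer itself is rejected too
```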
@@ -148,7 +148,7 @@ Base.show(io::IO, l::LSTMCell) =
   print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷4, ")")
 
 """
-    LSTM(in::Integer, out::Integer, σ = tanh)
+    LSTM(in::Integer, out::Integer)
 
 Long Short Term Memory recurrent layer. Behaves like an RNN but generally
 exhibits a longer memory span over sequences.
@@ -189,7 +189,7 @@ Base.show(io::IO, l::GRUCell) =
   print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")
 
 """
-    GRU(in::Integer, out::Integer, σ = tanh)
+    GRU(in::Integer, out::Integer)
 
 Gated Recurrent Unit layer. Behaves like an RNN but generally
 exhibits a longer memory span over sequences.
@@ -108,10 +108,8 @@ param(xs::AbstractArray) = TrackedArray(float.(xs))
 param(x::TrackedReal) = track(identity, x)
 param(x::TrackedArray) = track(identity, x)
 
-import NNlib.cudata
 import Adapt.adapt
 
-cudata(x::TrackedArray) = data(x)
 adapt(T, xs::TrackedArray) = param(adapt(T, data(xs)))
 
 end
@@ -1,6 +1,8 @@
 import Base: *
 
 import LinearAlgebra
+import LinearAlgebra: inv, \, /
+
 using Statistics
 using LinearAlgebra: Transpose, Adjoint, diagm, diag
 
@@ -41,6 +43,8 @@ end
 
 Base.print_array(io::IO, x::TrackedArray) = Base.print_array(io, data(x))
 
+Base.copy(x::TrackedArray) = x
+
 Base.setindex!(xs::TrackedArray, v, i...) =
   error("Can't differentiate `setindex!`")
 
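The new `Base.copy` method is deliberately a no-op, so copying a tracked array keeps it on the tape rather than detaching it:

```julia
x = param([1.0, 2.0])
copy(x) === x   # true: `copy` returns the tracked array itself
```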
@@ -205,6 +209,41 @@ Base.kron(a::TrackedMatrix, b::TrackedMatrix) = _kron(a, b)
 Base.kron(a::TrackedMatrix, b::AbstractMatrix) = _kron(a, b)
 Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b)
 
+
+inv(A::TrackedArray) = Tracker.track(inv, A)
+@grad function inv(A)
+  return inv(Tracker.data(A)), function (Δ)
+    Ainv = inv(A)
+    ∇A = - Ainv' * Δ * Ainv'
+    return (∇A, )
+  end
+end
+
+# (/) rdivide
+A::TrackedArray / B::TrackedArray = Tracker.track(/, A, B)
+A::AbstractVecOrMat / B::TrackedArray = Tracker.track(/, A, B)
+A::TrackedArray / B::AbstractVecOrMat = Tracker.track(/, A, B)
+@grad function (A / B)
+  return Tracker.data(A) / Tracker.data(B), function (Δ)
+    Binv = inv(B)
+    ∇B = - Binv' * A' * Δ * Binv'
+    return (Δ * Binv', ∇B)
+  end
+end
+
+# (\) ldivide (left vec divide needs more work to resolve dispatch ambiguity)
+A::TrackedArray \ B::TrackedArray = Tracker.track(\, A, B)
+A::AbstractArray \ B::TrackedArray = Tracker.track(\, A, B)
+A::TrackedArray \ B::AbstractVecOrMat = Tracker.track(\, A, B)
+@grad function (A \ B)
+  return Tracker.data(A) \ Tracker.data(B), function (Δ)
+    Ainv = inv(A)
+    ∇A = - Ainv' * Δ * B' * Ainv'
+    return (∇A, Ainv' * Δ)
+  end
+end
+
+
 # Reductions
 
 Base.sum(xs::TrackedArray; dims = :) = track(sum, xs, dims = dims)
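A hedged numerical check of the new `inv` adjoint, using the public `Tracker.gradient`; for a `sum` the incoming sensitivity is a matrix of ones, so the rule `∇A = -Ainv' * Δ * Ainv'` above predicts:

```julia
using Flux.Tracker

A  = [2.0 1.0; 1.0 3.0]
dA = Tracker.gradient(A -> sum(inv(A)), A)[1]

# analytic gradient of sum(inv(A)) with respect to A
isapprox(Tracker.data(dA), -inv(A)' * ones(2, 2) * inv(A)'; atol = 1e-8)  # true
```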
@@ -337,8 +376,7 @@ unbroadcast(x::AbstractArray, Δ) =
   trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
 
 unbroadcast(x::Number, Δ) = sum(Δ)
-unbroadcast(x::Base.RefValue{<:Function}, _) = nothing
-unbroadcast(x::Base.RefValue{<:Val}, _) = nothing
+unbroadcast(x::Base.RefValue, _) = nothing
 
 dual(x, p) = x
 dual(x::Real, p) = Dual(x, p)
@@ -353,9 +391,9 @@ end
   eltype(y) <: Real || return y
   eltype(y) == Bool && return y
   function back(Δ)
-    Δargs = ntuple(i -> partial.(f, data(Δ), i, args...), Val(N))
-    dxs = unbroadcast.(args, Δargs)
-    return nobacksies(:broadcast, dxs)
+    Δargs = ntuple(i -> partial.(f, Δ, i, args...), Val(N))
+    dxs = map(unbroadcast, args, Δargs)
+    return dxs
   end
   # So we can return non-tracked arrays
   track(Call(back, tracker.(args)), y)
@@ -23,6 +23,8 @@ end
 
 Base.decompose(x::TrackedReal) = Base.decompose(data(x))
 
+Base.copy(x::TrackedReal) = x
+
 Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x
 
 Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))
@@ -58,12 +60,18 @@ for (M, f, arity) in DiffRules.diffrules()
   end
 end
 
+# Work around zero(π) not working, for some reason
+_zero(::Irrational) = nothing
+_zero(x) = zero(x)
+
 for (M, f, arity) in DiffRules.diffrules()
   arity == 2 || continue
   da, db = DiffRules.diffrule(M, f, :a, :b)
   f = :($M.$f)
   @eval begin
-    @grad $f(a::Real, b::Real) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db)
+    @grad $f(a::TrackedReal, b::TrackedReal) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db)
+    @grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ -> (Δ * $da, _zero(b))
+    @grad $f(a::Real, b::TrackedReal) = $f(a, data(b)), Δ -> (_zero(a), Δ * $db)
     $f(a::TrackedReal, b::TrackedReal) = track($f, a, b)
     $f(a::TrackedReal, b::Real) = track($f, a, b)
     $f(a::Real, b::TrackedReal) = track($f, a, b)
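With the extra `@grad` signatures, binary DiffRules functions now differentiate cleanly when only one argument is tracked. A small sketch:

```julia
using Flux.Tracker

# 3.0 stays a plain Real while x is tracked, hitting the new mixed-argument rule
g = Tracker.gradient(x -> 3.0 * log(x), 2.0)[1]
Tracker.data(g) ≈ 3.0 / 2.0   # true
```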
@@ -54,7 +54,7 @@ function loadparams!(m, xs)
   for (p, x) in zip(params(m), xs)
     size(p) == size(x) ||
       error("Expected param size $(size(p)), got $(size(x))")
-    copy!(data(p), data(x))
+    copyto!(data(p), data(x))
   end
 end
 
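`loadparams!` now uses `copyto!`, the Julia 1.0 name for in-place `copy!`. A usage sketch for copying weights between two identically-shaped models:

```julia
using Flux

m1 = Dense(3, 2)
m2 = Dense(3, 2)
Flux.loadparams!(m2, params(m1))   # m2's W and b now hold copies of m1's values
```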
src/utils.jl
@@ -24,7 +24,7 @@ julia> chunk(1:10, 3)
 """
 chunk(xs, n) = collect(Iterators.partition(xs, ceil(Int, length(xs)/n)))
 
-batchindex(xs, i) = (reverse(Base.tail(reverse(indices(xs))))..., i)
+batchindex(xs, i) = (reverse(Base.tail(reverse(axes(xs))))..., i)
 
 """
     frequencies(xs)
@@ -66,7 +66,7 @@ julia> batch([[1,2,3],[4,5,6]])
 function batch(xs)
   data = first(xs) isa AbstractArray ?
     similar(first(xs), size(first(xs))..., length(xs)) :
-    Vector{eltype(xs)}(length(xs))
+    Vector{eltype(xs)}(undef, length(xs))
   for (i, x) in enumerate(xs)
     data[batchindex(data, i)...] = x
   end
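`Vector{T}(undef, n)` is the Julia 1.0 spelling of an uninitialised vector; `batch` itself behaves as the docstring in the hunk header shows:

```julia
using Flux

Flux.batch([[1, 2, 3], [4, 5, 6]])   # 3×2 Matrix: each sample becomes a column
```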
@@ -153,3 +153,18 @@ function jacobian(m,x)
   end
   J'
 end
+
+"""
+    @jit ...
+
+The `@jit` annotation can be applied to any code, and the code will be compiled
+for performance.
+
+    @jit f(x) = @jit(x) + @jit(x)
+
+Note that compilation happens regardless of the `@jit` macro, so it should only
+be used for aesthetic purposes, or by recovering Python users.
+"""
+macro jit(ex)
+  esc(ex)
+end
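As the docstring says, `@jit` is purely cosmetic: the macro just escapes its argument, so annotated code runs exactly as it would without it.

```julia
using Flux

Flux.@jit sum(abs2, rand(3))   # identical to sum(abs2, rand(3))
```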
@@ -1,42 +1,42 @@
 using Flux, Flux.Tracker, CuArrays, Test
 using Flux: gpu
 
-# @info "Testing GPU Support"
-#
-# @testset "CuArrays" begin
-#
-# CuArrays.allowscalar(false)
-#
-# x = param(randn(5, 5))
-# cx = gpu(x)
-# @test cx isa TrackedArray && cx.data isa CuArray
-#
-# x = Flux.onehotbatch([1, 2, 3], 1:3)
-# cx = gpu(x)
-# @test cx isa Flux.OneHotMatrix && cx.data isa CuArray
-# @test (cx .+ 1) isa CuArray
-#
-# m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
-# cm = gpu(m)
-#
-# @test all(p isa TrackedArray && p.data isa CuArray for p in params(cm))
-# @test cm(gpu(rand(10, 10))) isa TrackedArray{Float32,2,CuArray{Float32,2}}
-#
-# x = [1,2,3]
-# cx = gpu(x)
-# @test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)
-#
-# xs = param(rand(5,5))
-# ys = Flux.onehotbatch(1:5,1:5)
-# @test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
-#
-# c = gpu(Conv((2,2),3=>4))
-# l = c(gpu(rand(10,10,3,2)))
-# Flux.back!(sum(l))
-#
-# end
+@info "Testing GPU Support"
+
+@testset "CuArrays" begin
+
+CuArrays.allowscalar(false)
+
+x = param(randn(5, 5))
+cx = gpu(x)
+@test cx isa TrackedArray && cx.data isa CuArray
+
+x = Flux.onehotbatch([1, 2, 3], 1:3)
+cx = gpu(x)
+@test cx isa Flux.OneHotMatrix && cx.data isa CuArray
+@test (cx .+ 1) isa CuArray
+
+m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
+cm = gpu(m)
+
+@test all(p isa TrackedArray && p.data isa CuArray for p in params(cm))
+@test cm(gpu(rand(10, 10))) isa TrackedArray{Float32,2,CuArray{Float32,2}}
+
+x = [1,2,3]
+cx = gpu(x)
+@test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)
+
+xs = param(rand(5,5))
+ys = Flux.onehotbatch(1:5,1:5)
+@test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
+
+c = gpu(Conv((2,2),3=>4))
+l = c(gpu(rand(10,10,3,2)))
+Flux.back!(sum(l))
+
+end
 
-if CuArrays.cudnn_available()
+if CuArrays.libcudnn != nothing
 @info "Testing Flux/CUDNN BatchNorm"
 include("cudnn.jl")
 @info "Testing Flux/CUDNN RNN"
@@ -9,3 +9,8 @@ using Test
 
 @test MNIST.images()[1] isa Matrix
 @test MNIST.labels() isa Vector{Int64}
+
+@test FashionMNIST.images()[1] isa Matrix
+@test FashionMNIST.labels() isa Vector{Int64}
+
+@test Data.Sentiment.train() isa Vector{Data.Tree{Any}}
test/layers/basic.jl (new file)
@@ -0,0 +1,33 @@
+using Test, Random
+
+@testset "basic" begin
+  @testset "Chain" begin
+    @test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10))
+    @test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn(10))
+    # numeric test should be put into testset of corresponding layer
+  end
+
+  @testset "Dense" begin
+    @test length(Dense(10, 5)(randn(10))) == 5
+    @test_throws DimensionMismatch Dense(10, 5)(randn(1))
+    @test_throws MethodError Dense(10, 5)(1) # avoid broadcasting
+    @test_throws MethodError Dense(10, 5).(randn(10)) # avoid broadcasting
+
+    @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(1, 1)
+    @test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == 10*ones(1, 2)
+    @test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(2, 1)
+    @test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
+
+  end
+
+  @testset "Diagonal" begin
+    @test length(Flux.Diagonal(10)(randn(10))) == 10
+    @test length(Flux.Diagonal(10)(1)) == 10
+    @test length(Flux.Diagonal(10)(randn(1))) == 10
+    @test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))
+
+    @test Flux.Diagonal(2)([1 2]) == [1 2; 1 2]
+    @test Flux.Diagonal(2)([1,2]) == [1,2]
+    @test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4]
+  end
+end
@@ -32,6 +32,7 @@ include("data.jl")
 
 @info "Testing Layers"
 
+include("layers/basic.jl")
 include("layers/normalisation.jl")
 include("layers/stateless.jl")
 include("layers/conv.jl")
@@ -129,6 +129,11 @@ end
 
 @test gradtest(f-> Matrix(Diagonal(f)), rand(3))
 
+@test gradtest(W -> inv(log.(W * W)), (5,5))
+@test gradtest((A, B) -> A / B , (1,5), (5,5))
+@test gradtest((A, B) -> log.(A * A) / exp.(B * B), (5,5), (5,5))
+@test gradtest((A, B) -> log.(A * A) \ exp.(B * B), (5,5), (5,5))
+
 @testset "mean" begin
   @test gradtest(mean, rand(2, 3))
 