Fix merge conflicts
This commit is contained in:
commit
2559e7b4e6
|
@ -15,5 +15,5 @@ matrix:
|
|||
allow_failures:
|
||||
- julia: nightly
|
||||
after_success:
|
||||
- julia -e 'using Pkg; Pkg.add("Documenter"); Pkg.add("NNlib")'
|
||||
- julia -e 'using Pkg; ps=Pkg.PackageSpec(name="Documenter", version="0.19"); Pkg.add(ps); Pkg.pin(ps); Pkg.add("NNlib")'
|
||||
- julia -e 'using Pkg; cd(Pkg.dir("Flux")); include(joinpath("docs", "make.jl"))'
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
Flux is a library for machine learning. It comes "batteries-included" with many useful tools built in, but also lets you use the full power of the Julia language where you need it. We follow a few key principles:
|
||||
|
||||
* **Doing the obvious thing**. Flux has relatively few explicit APIs for features like regularisation or embeddings. Instead, writing down the mathematical form will work – and be fast.
|
||||
* **You could have written Flux**. All of it, from [LSTMs](https://github.com/FluxML/Flux.jl/blob/ec16a2c77dbf6ab8b92b0eecd11661be7a62feef/src/layers/recurrent.jl#L131) to [GPU kernels](https://github.com/JuliaGPU/CuArrays.jl), is straightforward Julia code. When it doubt, it’s well worth looking at [the source](https://github.com/FluxML/Flux.jl/). If you need something different, you can easily roll your own.
|
||||
* **You could have written Flux**. All of it, from [LSTMs](https://github.com/FluxML/Flux.jl/blob/ec16a2c77dbf6ab8b92b0eecd11661be7a62feef/src/layers/recurrent.jl#L131) to [GPU kernels](https://github.com/JuliaGPU/CuArrays.jl), is straightforward Julia code. When in doubt, it’s well worth looking at [the source](https://github.com/FluxML/Flux.jl/). If you need something different, you can easily roll your own.
|
||||
* **Play nicely with others**. Flux works well with Julia libraries from [data frames](https://github.com/JuliaComputing/JuliaDB.jl) and [images](https://github.com/JuliaImages/Images.jl) to [differential equation solvers](https://github.com/JuliaDiffEq/DifferentialEquations.jl), so you can easily build complex data processing pipelines that integrate Flux models.
|
||||
|
||||
## Installation
|
||||
|
|
|
@ -100,16 +100,16 @@ minus(a, b) = a - b
|
|||
Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch:
|
||||
|
||||
```julia
|
||||
using Flux.Tracker: TrackedReal, track, @grad
|
||||
using Flux.Tracker: TrackedArray, track, @grad
|
||||
|
||||
minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b)
|
||||
minus(a::TrackedArray, b::TrackedArray) = track(minus, a, b)
|
||||
```
|
||||
|
||||
`track` takes care of building a new `Tracked` object and recording the operation on the tape. We just need to provide a gradient definition.
|
||||
|
||||
```julia
|
||||
@grad function minus(a, b)
|
||||
return minus(data(a),data(b)), Δ -> (Δ, -Δ)
|
||||
return minus(data(a), data(b)), Δ -> (Δ, -Δ)
|
||||
end
|
||||
```
|
||||
|
||||
|
@ -121,6 +121,19 @@ Note that in the backpropagator we don't call `data(a)`; we *do* in fact want to
|
|||
@grad a * b = data(a)*data(b), Δ -> (Δ*b, a*Δ)
|
||||
```
|
||||
|
||||
We can then calculate the first derivative of `minus` as follows:
|
||||
|
||||
```julia
|
||||
a = param([1,2,3])
|
||||
b = param([3,2,1])
|
||||
|
||||
c = minus(a, b) # [-2.0 (tracked), 0.0 (tracked), 2.0 (tracked)]
|
||||
|
||||
Tracker.back!(c, 1)
|
||||
Tracker.grad(a) # [1.00, 1.00, 1.00]
|
||||
Tracker.grad(b) # [-1.00, -1.00, -1.00]
|
||||
```
|
||||
|
||||
For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, TrackedArray)` and so on. To do so, just define those extra signatures as needed:
|
||||
|
||||
```julia
|
||||
|
|
|
@ -10,14 +10,14 @@ using Flux.Tracker
|
|||
f(x) = 3x^2 + 2x + 1
|
||||
|
||||
# df/dx = 6x + 2
|
||||
f′(x) = Tracker.gradient(f, x)[1]
|
||||
df(x) = Tracker.gradient(f, x)[1]
|
||||
|
||||
f′(2) # 14.0 (tracked)
|
||||
df(2) # 14.0 (tracked)
|
||||
|
||||
# d²f/dx² = 6
|
||||
f′′(x) = Tracker.gradient(f′, x)[1]
|
||||
d2f(x) = Tracker.gradient(df, x)[1]
|
||||
|
||||
f′′(2) # 6.0 (tracked)
|
||||
d2f(2) # 6.0 (tracked)
|
||||
```
|
||||
|
||||
(We'll learn more about why these numbers show up as `(tracked)` below.)
|
||||
|
|
|
@ -2,7 +2,7 @@ module CUDA
|
|||
|
||||
using ..CuArrays
|
||||
|
||||
if CuArrays.cudnn_available()
|
||||
if CuArrays.libcudnn != nothing
|
||||
include("curnn.jl")
|
||||
include("cudnn.jl")
|
||||
end
|
||||
|
|
|
@ -24,10 +24,10 @@ const RNN_ALGO_PERSIST_DYNAMIC = 2
|
|||
# LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]
|
||||
|
||||
function params(w::CuVector, input, hidden, n = 1)
|
||||
slice(offset, shape) = reshape(w[offset.+(1:prod(shape))], shape)
|
||||
slice(offset, shape) = reshape(view(w, offset.+(1:prod(shape))), shape)
|
||||
wx = slice(0, (input, hidden*n))
|
||||
wh = slice(length(wx), (hidden, hidden*n))
|
||||
bias = w[length(wx)+length(wh) .+ (1:hidden*n)]
|
||||
bias = view(w, length(wx)+length(wh) .+ (1:hidden*n))
|
||||
(wx, wh), bias
|
||||
end
|
||||
|
||||
|
@ -306,7 +306,7 @@ end
|
|||
h_ = hBatch(x, data(h))
|
||||
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
|
||||
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
|
||||
nobacksies(:RNN, (dx, unbroadcast(size(h), dh), transpose(dWi), transpose(dWh), db))
|
||||
nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db))
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -320,7 +320,7 @@ end
|
|||
dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
|
||||
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
|
||||
nobacksies(:RNN,
|
||||
(dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc),
|
||||
(dx, unbroadcast(h, dh), unbroadcast(c, dc),
|
||||
transpose(dWi), transpose(dWh), db))
|
||||
end
|
||||
end
|
||||
|
|
|
@ -13,6 +13,9 @@ end
|
|||
include("mnist.jl")
|
||||
export MNIST
|
||||
|
||||
include("fashion-mnist.jl")
|
||||
export FashionMNIST
|
||||
|
||||
include("cmudict.jl")
|
||||
using .CMUDict
|
||||
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
module FashionMNIST
|
||||
|
||||
using ..MNIST: gzopen, imageheader, rawimage, labelheader, rawlabel
|
||||
|
||||
const dir = joinpath(@__DIR__, "../../deps/fashion-mnist")
|
||||
|
||||
function load()
|
||||
mkpath(dir)
|
||||
cd(dir) do
|
||||
for file in ["train-images-idx3-ubyte",
|
||||
"train-labels-idx1-ubyte",
|
||||
"t10k-images-idx3-ubyte",
|
||||
"t10k-labels-idx1-ubyte"]
|
||||
isfile(file) && continue
|
||||
@info "Downloading Fashion-MNIST dataset"
|
||||
download("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/$file.gz", "$file.gz")
|
||||
open(file, "w") do io
|
||||
write(io, gzopen(read, "$file.gz"))
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
const TRAINIMAGES = joinpath(dir, "train-images-idx3-ubyte")
|
||||
const TRAINLABELS = joinpath(dir, "train-labels-idx1-ubyte")
|
||||
const TESTIMAGES = joinpath(dir, "t10k-images-idx3-ubyte")
|
||||
const TESTLABELS = joinpath(dir, "t10k-labels-idx1-ubyte")
|
||||
|
||||
"""
|
||||
images()
|
||||
images(:test)
|
||||
|
||||
Load the Fashion-MNIST images.
|
||||
|
||||
Each image is a 28×28 array of `Gray` colour values (see Colors.jl).
|
||||
|
||||
Returns the 60,000 training images by default; pass `:test` to retreive the
|
||||
10,000 test images.
|
||||
"""
|
||||
function images(set = :train)
|
||||
load()
|
||||
io = IOBuffer(read(set == :train ? TRAINIMAGES : TESTIMAGES))
|
||||
_, N, nrows, ncols = imageheader(io)
|
||||
[rawimage(io) for _ in 1:N]
|
||||
end
|
||||
|
||||
"""
|
||||
labels()
|
||||
labels(:test)
|
||||
|
||||
Load the labels corresponding to each of the images returned from `images()`.
|
||||
Each label is a number from 0-9.
|
||||
|
||||
Returns the 60,000 training labels by default; pass `:test` to retreive the
|
||||
10,000 test labels.
|
||||
"""
|
||||
function labels(set = :train)
|
||||
load()
|
||||
io = IOBuffer(read(set == :train ? TRAINLABELS : TESTLABELS))
|
||||
_, N = labelheader(io)
|
||||
[rawlabel(io) for _ = 1:N]
|
||||
end
|
||||
|
||||
end
|
|
@ -4,7 +4,7 @@ using ZipFile
|
|||
using ..Data: deps
|
||||
|
||||
function load()
|
||||
isfile(deps("sentiment.zip")) || return
|
||||
isfile(deps("sentiment.zip")) && return
|
||||
@info "Downloading sentiment treebank dataset"
|
||||
download("https://cache.julialang.org/https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip",
|
||||
deps("sentiment.zip"))
|
||||
|
@ -26,9 +26,10 @@ totree_(n, a, b) = Tree{Any}((parse(Int, n), nothing), totree(a), totree(b))
|
|||
totree(t::Expr) = totree_(t.args...)
|
||||
|
||||
function parsetree(s)
|
||||
s = replace(s, r"\$", s -> "\\\$")
|
||||
s = replace(s, r"[^\s\(\)]+", s -> "\"$s\"")
|
||||
s = replace(s, " ", ", ")
|
||||
s = replace(s, "\\" => "")
|
||||
s = replace(s, "\$" => "\\\$")
|
||||
s = replace(s, r"[^ \n\(\)]+" => s -> "\"$s\"")
|
||||
s = replace(s, " " => ", ")
|
||||
return totree(Meta.parse(s))
|
||||
end
|
||||
|
||||
|
|
|
@ -75,7 +75,7 @@ end
|
|||
|
||||
@treelike Dense
|
||||
|
||||
function (a::Dense)(x)
|
||||
function (a::Dense)(x::AbstractArray)
|
||||
W, b, σ = a.W, a.b, a.σ
|
||||
σ.(W*x .+ b)
|
||||
end
|
||||
|
|
|
@ -148,7 +148,7 @@ Base.show(io::IO, l::LSTMCell) =
|
|||
print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷4, ")")
|
||||
|
||||
"""
|
||||
LSTM(in::Integer, out::Integer, σ = tanh)
|
||||
LSTM(in::Integer, out::Integer)
|
||||
|
||||
Long Short Term Memory recurrent layer. Behaves like an RNN but generally
|
||||
exhibits a longer memory span over sequences.
|
||||
|
@ -189,7 +189,7 @@ Base.show(io::IO, l::GRUCell) =
|
|||
print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")
|
||||
|
||||
"""
|
||||
GRU(in::Integer, out::Integer, σ = tanh)
|
||||
GRU(in::Integer, out::Integer)
|
||||
|
||||
Gated Recurrent Unit layer. Behaves like an RNN but generally
|
||||
exhibits a longer memory span over sequences.
|
||||
|
|
|
@ -108,10 +108,8 @@ param(xs::AbstractArray) = TrackedArray(float.(xs))
|
|||
param(x::TrackedReal) = track(identity, x)
|
||||
param(x::TrackedArray) = track(identity, x)
|
||||
|
||||
import NNlib.cudata
|
||||
import Adapt.adapt
|
||||
|
||||
cudata(x::TrackedArray) = data(x)
|
||||
adapt(T, xs::TrackedArray) = param(adapt(T, data(xs)))
|
||||
|
||||
end
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import Base: *
|
||||
|
||||
import LinearAlgebra
|
||||
import LinearAlgebra: inv, \, /
|
||||
|
||||
using Statistics
|
||||
using LinearAlgebra: Transpose, Adjoint, diagm, diag
|
||||
|
||||
|
@ -41,6 +43,8 @@ end
|
|||
|
||||
Base.print_array(io::IO, x::TrackedArray) = Base.print_array(io, data(x))
|
||||
|
||||
Base.copy(x::TrackedArray) = x
|
||||
|
||||
Base.setindex!(xs::TrackedArray, v, i...) =
|
||||
error("Can't differentiate `setindex!`")
|
||||
|
||||
|
@ -205,6 +209,41 @@ Base.kron(a::TrackedMatrix, b::TrackedMatrix) = _kron(a, b)
|
|||
Base.kron(a::TrackedMatrix, b::AbstractMatrix) = _kron(a, b)
|
||||
Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b)
|
||||
|
||||
|
||||
inv(A::TrackedArray) = Tracker.track(inv, A)
|
||||
@grad function inv(A)
|
||||
return inv(Tracker.data(A)), function (Δ)
|
||||
Ainv = inv(A)
|
||||
∇A = - Ainv' * Δ * Ainv'
|
||||
return (∇A, )
|
||||
end
|
||||
end
|
||||
|
||||
# (/) rdivide
|
||||
A::TrackedArray / B::TrackedArray = Tracker.track(/, A, B)
|
||||
A::AbstractVecOrMat / B::TrackedArray = Tracker.track(/, A, B)
|
||||
A::TrackedArray / B::AbstractVecOrMat = Tracker.track(/, A, B)
|
||||
@grad function (A / B)
|
||||
return Tracker.data(A) / Tracker.data(B), function (Δ)
|
||||
Binv = inv(B)
|
||||
∇B = - Binv' * A' * Δ * Binv'
|
||||
return (Δ * Binv', ∇B)
|
||||
end
|
||||
end
|
||||
|
||||
# (\) ldivide (left vec divide needs more work to resolve dispatch ambiguity)
|
||||
A::TrackedArray \ B::TrackedArray = Tracker.track(\, A, B)
|
||||
A::AbstractArray \ B::TrackedArray = Tracker.track(\, A, B)
|
||||
A::TrackedArray \ B::AbstractVecOrMat = Tracker.track(\, A, B)
|
||||
@grad function (A \ B)
|
||||
return Tracker.data(A) \ Tracker.data(B), function (Δ)
|
||||
Ainv = inv(A)
|
||||
∇A = - Ainv' * Δ * B' * Ainv'
|
||||
return (∇A, Ainv' * Δ)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
# Reductions
|
||||
|
||||
Base.sum(xs::TrackedArray; dims = :) = track(sum, xs, dims = dims)
|
||||
|
@ -337,8 +376,7 @@ unbroadcast(x::AbstractArray, Δ) =
|
|||
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
|
||||
|
||||
unbroadcast(x::Number, Δ) = sum(Δ)
|
||||
unbroadcast(x::Base.RefValue{<:Function}, _) = nothing
|
||||
unbroadcast(x::Base.RefValue{<:Val}, _) = nothing
|
||||
unbroadcast(x::Base.RefValue, _) = nothing
|
||||
|
||||
dual(x, p) = x
|
||||
dual(x::Real, p) = Dual(x, p)
|
||||
|
@ -353,9 +391,9 @@ end
|
|||
eltype(y) <: Real || return y
|
||||
eltype(y) == Bool && return y
|
||||
function back(Δ)
|
||||
Δargs = ntuple(i -> partial.(f, data(Δ), i, args...), Val(N))
|
||||
dxs = unbroadcast.(args, Δargs)
|
||||
return nobacksies(:broadcast, dxs)
|
||||
Δargs = ntuple(i -> partial.(f, Δ, i, args...), Val(N))
|
||||
dxs = map(unbroadcast, args, Δargs)
|
||||
return dxs
|
||||
end
|
||||
# So we can return non-tracked arrays
|
||||
track(Call(back, tracker.(args)), y)
|
||||
|
|
|
@ -23,6 +23,8 @@ end
|
|||
|
||||
Base.decompose(x::TrackedReal) = Base.decompose(data(x))
|
||||
|
||||
Base.copy(x::TrackedReal) = x
|
||||
|
||||
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x
|
||||
|
||||
Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))
|
||||
|
@ -58,12 +60,18 @@ for (M, f, arity) in DiffRules.diffrules()
|
|||
end
|
||||
end
|
||||
|
||||
# Work around zero(π) not working, for some reason
|
||||
_zero(::Irrational) = nothing
|
||||
_zero(x) = zero(x)
|
||||
|
||||
for (M, f, arity) in DiffRules.diffrules()
|
||||
arity == 2 || continue
|
||||
da, db = DiffRules.diffrule(M, f, :a, :b)
|
||||
f = :($M.$f)
|
||||
@eval begin
|
||||
@grad $f(a::Real, b::Real) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db)
|
||||
@grad $f(a::TrackedReal, b::TrackedReal) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db)
|
||||
@grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ -> (Δ * $da, _zero(b))
|
||||
@grad $f(a::Real, b::TrackedReal) = $f(a, data(b)), Δ -> (_zero(a), Δ * $db)
|
||||
$f(a::TrackedReal, b::TrackedReal) = track($f, a, b)
|
||||
$f(a::TrackedReal, b::Real) = track($f, a, b)
|
||||
$f(a::Real, b::TrackedReal) = track($f, a, b)
|
||||
|
|
|
@ -54,7 +54,7 @@ function loadparams!(m, xs)
|
|||
for (p, x) in zip(params(m), xs)
|
||||
size(p) == size(x) ||
|
||||
error("Expected param size $(size(p)), got $(size(x))")
|
||||
copy!(data(p), data(x))
|
||||
copyto!(data(p), data(x))
|
||||
end
|
||||
end
|
||||
|
||||
|
|
19
src/utils.jl
19
src/utils.jl
|
@ -24,7 +24,7 @@ julia> chunk(1:10, 3)
|
|||
"""
|
||||
chunk(xs, n) = collect(Iterators.partition(xs, ceil(Int, length(xs)/n)))
|
||||
|
||||
batchindex(xs, i) = (reverse(Base.tail(reverse(indices(xs))))..., i)
|
||||
batchindex(xs, i) = (reverse(Base.tail(reverse(axes(xs))))..., i)
|
||||
|
||||
"""
|
||||
frequencies(xs)
|
||||
|
@ -66,7 +66,7 @@ julia> batch([[1,2,3],[4,5,6]])
|
|||
function batch(xs)
|
||||
data = first(xs) isa AbstractArray ?
|
||||
similar(first(xs), size(first(xs))..., length(xs)) :
|
||||
Vector{eltype(xs)}(length(xs))
|
||||
Vector{eltype(xs)}(undef, length(xs))
|
||||
for (i, x) in enumerate(xs)
|
||||
data[batchindex(data, i)...] = x
|
||||
end
|
||||
|
@ -153,3 +153,18 @@ function jacobian(m,x)
|
|||
end
|
||||
J'
|
||||
end
|
||||
|
||||
"""
|
||||
@jit ...
|
||||
|
||||
The `@jit` annotation can be applied to any code, and the code will be compiled
|
||||
for performance.
|
||||
|
||||
@jit f(x) = @jit(x) + @jit(x)
|
||||
|
||||
Note that compilation happens regardless of the `@jit` macro, so it should only
|
||||
be used for aesthetic purposes, or by recovering Python users.
|
||||
"""
|
||||
macro jit(ex)
|
||||
esc(ex)
|
||||
end
|
||||
|
|
|
@ -1,42 +1,42 @@
|
|||
using Flux, Flux.Tracker, CuArrays, Test
|
||||
using Flux: gpu
|
||||
|
||||
# @info "Testing GPU Support"
|
||||
#
|
||||
# @testset "CuArrays" begin
|
||||
#
|
||||
# CuArrays.allowscalar(false)
|
||||
#
|
||||
# x = param(randn(5, 5))
|
||||
# cx = gpu(x)
|
||||
# @test cx isa TrackedArray && cx.data isa CuArray
|
||||
#
|
||||
# x = Flux.onehotbatch([1, 2, 3], 1:3)
|
||||
# cx = gpu(x)
|
||||
# @test cx isa Flux.OneHotMatrix && cx.data isa CuArray
|
||||
# @test (cx .+ 1) isa CuArray
|
||||
#
|
||||
# m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
|
||||
# cm = gpu(m)
|
||||
#
|
||||
# @test all(p isa TrackedArray && p.data isa CuArray for p in params(cm))
|
||||
# @test cm(gpu(rand(10, 10))) isa TrackedArray{Float32,2,CuArray{Float32,2}}
|
||||
#
|
||||
# x = [1,2,3]
|
||||
# cx = gpu(x)
|
||||
# @test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)
|
||||
#
|
||||
# xs = param(rand(5,5))
|
||||
# ys = Flux.onehotbatch(1:5,1:5)
|
||||
# @test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
|
||||
#
|
||||
# c = gpu(Conv((2,2),3=>4))
|
||||
# l = c(gpu(rand(10,10,3,2)))
|
||||
# Flux.back!(sum(l))
|
||||
#
|
||||
# end
|
||||
@info "Testing GPU Support"
|
||||
|
||||
if CuArrays.cudnn_available()
|
||||
@testset "CuArrays" begin
|
||||
|
||||
CuArrays.allowscalar(false)
|
||||
|
||||
x = param(randn(5, 5))
|
||||
cx = gpu(x)
|
||||
@test cx isa TrackedArray && cx.data isa CuArray
|
||||
|
||||
x = Flux.onehotbatch([1, 2, 3], 1:3)
|
||||
cx = gpu(x)
|
||||
@test cx isa Flux.OneHotMatrix && cx.data isa CuArray
|
||||
@test (cx .+ 1) isa CuArray
|
||||
|
||||
m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
|
||||
cm = gpu(m)
|
||||
|
||||
@test all(p isa TrackedArray && p.data isa CuArray for p in params(cm))
|
||||
@test cm(gpu(rand(10, 10))) isa TrackedArray{Float32,2,CuArray{Float32,2}}
|
||||
|
||||
x = [1,2,3]
|
||||
cx = gpu(x)
|
||||
@test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)
|
||||
|
||||
xs = param(rand(5,5))
|
||||
ys = Flux.onehotbatch(1:5,1:5)
|
||||
@test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
|
||||
|
||||
c = gpu(Conv((2,2),3=>4))
|
||||
l = c(gpu(rand(10,10,3,2)))
|
||||
Flux.back!(sum(l))
|
||||
|
||||
end
|
||||
|
||||
if CuArrays.libcudnn != nothing
|
||||
@info "Testing Flux/CUDNN BatchNorm"
|
||||
include("cudnn.jl")
|
||||
@info "Testing Flux/CUDNN RNN"
|
||||
|
|
|
@ -9,3 +9,8 @@ using Test
|
|||
|
||||
@test MNIST.images()[1] isa Matrix
|
||||
@test MNIST.labels() isa Vector{Int64}
|
||||
|
||||
@test FashionMNIST.images()[1] isa Matrix
|
||||
@test FashionMNIST.labels() isa Vector{Int64}
|
||||
|
||||
@test Data.Sentiment.train() isa Vector{Data.Tree{Any}}
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
using Test, Random
|
||||
|
||||
@testset "basic" begin
|
||||
@testset "Chain" begin
|
||||
@test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10))
|
||||
@test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn(10))
|
||||
# numeric test should be put into testset of corresponding layer
|
||||
end
|
||||
|
||||
@testset "Dense" begin
|
||||
@test length(Dense(10, 5)(randn(10))) == 5
|
||||
@test_throws DimensionMismatch Dense(10, 5)(randn(1))
|
||||
@test_throws MethodError Dense(10, 5)(1) # avoid broadcasting
|
||||
@test_throws MethodError Dense(10, 5).(randn(10)) # avoid broadcasting
|
||||
|
||||
@test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(1, 1)
|
||||
@test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == 10*ones(1, 2)
|
||||
@test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(2, 1)
|
||||
@test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
|
||||
|
||||
end
|
||||
|
||||
@testset "Diagonal" begin
|
||||
@test length(Flux.Diagonal(10)(randn(10))) == 10
|
||||
@test length(Flux.Diagonal(10)(1)) == 10
|
||||
@test length(Flux.Diagonal(10)(randn(1))) == 10
|
||||
@test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))
|
||||
|
||||
@test Flux.Diagonal(2)([1 2]) == [1 2; 1 2]
|
||||
@test Flux.Diagonal(2)([1,2]) == [1,2]
|
||||
@test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4]
|
||||
end
|
||||
end
|
|
@ -32,6 +32,7 @@ include("data.jl")
|
|||
|
||||
@info "Testing Layers"
|
||||
|
||||
include("layers/basic.jl")
|
||||
include("layers/normalisation.jl")
|
||||
include("layers/stateless.jl")
|
||||
include("layers/conv.jl")
|
||||
|
|
|
@ -129,6 +129,11 @@ end
|
|||
|
||||
@test gradtest(f-> Matrix(Diagonal(f)), rand(3))
|
||||
|
||||
@test gradtest(W -> inv(log.(W * W)), (5,5))
|
||||
@test gradtest((A, B) -> A / B , (1,5), (5,5))
|
||||
@test gradtest((A, B) -> log.(A * A) / exp.(B * B), (5,5), (5,5))
|
||||
@test gradtest((A, B) -> log.(A * A) \ exp.(B * B), (5,5), (5,5))
|
||||
|
||||
@testset "mean" begin
|
||||
@test gradtest(mean, rand(2, 3))
|
||||
|
||||
|
|
Loading…
Reference in New Issue