Merge branch 'master' into zygote

Mike J Innes 2019-09-06 15:18:58 +01:00
commit 67c38b3099
16 changed files with 140 additions and 55 deletions


@@ -33,12 +33,35 @@ git-tree-sha1 = "c7361ce8a2129f20b0e05a89f7070820cfed6648"
 uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
 version = "0.5.6"
+[[CEnum]]
+git-tree-sha1 = "62847acab40e6855a9b5905ccb99c2b5cf6b3ebb"
+uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
+version = "0.2.0"
 [[CSTParser]]
 deps = ["Tokenize"]
 git-tree-sha1 = "c69698c3d4a7255bc1b4bc2afc09f59db910243b"
 uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
 version = "0.6.2"
+[[CUDAapi]]
+deps = ["Libdl", "Logging"]
+git-tree-sha1 = "9b2b4b71d6b7f946c9689bb4dea03ff92e3c7091"
+uuid = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
+version = "1.1.0"
+[[CUDAdrv]]
+deps = ["CUDAapi", "Libdl", "Printf"]
+git-tree-sha1 = "9ce99b5732c70e06ed97c042187baed876fb1698"
+uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
+version = "3.1.0"
+[[CUDAnative]]
+deps = ["Adapt", "CUDAapi", "CUDAdrv", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Printf", "TimerOutputs"]
+git-tree-sha1 = "0a00bef482b7c9127495c7f4a2a85e73b13b5af8"
+uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
+version = "2.3.0"
 [[CodecZlib]]
 deps = ["BinaryProvider", "Libdl", "TranscodingStreams"]
 git-tree-sha1 = "05916673a2627dd91b4969ff8ba6941bc85a960e"
@@ -81,6 +104,12 @@ git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
 uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
 version = "4.0.0"
+[[CuArrays]]
+deps = ["AbstractFFTs", "Adapt", "CUDAapi", "CUDAdrv", "CUDAnative", "GPUArrays", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"]
+git-tree-sha1 = "46b48742a84bb839e74215b7e468a4a1c6ba30f9"
+uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
+version = "1.2.1"
 [[DataAPI]]
 git-tree-sha1 = "8903f0219d3472543fc4b2f5ebaf675a07f817c0"
 uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
@@ -124,9 +153,9 @@ version = "0.3.0"
 [[FillArrays]]
 deps = ["LinearAlgebra", "Random", "SparseArrays"]
-git-tree-sha1 = "4c707c87ddd3199fc5624d5c98b2c706e4d00675"
+git-tree-sha1 = "8fba6ddaf66b45dec830233cea0aae43eb1261ad"
 uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
-version = "0.7.0"
+version = "0.6.4"
 [[FixedPointNumbers]]
 git-tree-sha1 = "d14a6fa5890ea3a7e5dcab6811114f132fec2b4b"
@@ -139,6 +168,12 @@ git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
 uuid = "f6369f11-7733-5829-9624-2563aa707210"
 version = "0.10.3"
+[[GPUArrays]]
+deps = ["Adapt", "FFTW", "FillArrays", "LinearAlgebra", "Printf", "Random", "Serialization", "StaticArrays", "Test"]
+git-tree-sha1 = "dd169c636d1d3656a9faca772f5bd7c226a61254"
+uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
+version = "1.0.1"
 [[IRTools]]
 deps = ["InteractiveUtils", "MacroTools", "Test"]
 git-tree-sha1 = "e23faa71b8f54c3fdc99b230b9c2906cafdddca5"
@@ -161,6 +196,12 @@ git-tree-sha1 = "30d94657a422d09cb97b6f86f04f750fa9c50df8"
 uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
 version = "0.7.2"
+[[LLVM]]
+deps = ["CEnum", "Libdl", "Printf", "Unicode"]
+git-tree-sha1 = "52cfea426bd248a427aace7d88eb5d45b84ea297"
+uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
+version = "1.2.0"
 [[LibGit2]]
 uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"


@@ -1,6 +1,7 @@
 # v0.9.0
 * [Depthwise convolutional layer API changes](https://github.com/FluxML/Flux.jl/pull/756) from `in => mult` channel specification to `in => out` channel specification, and deprecates implicit `out` constructor.
 * New [SkipConnection](https://github.com/FluxML/Flux.jl/pull/446), which can be used to train residual neural network architectures.
+* New [RADAM](https://github.com/FluxML/Flux.jl/pull/842) optimiser.
 # v0.8.0


@@ -1,12 +1,14 @@
 name = "Flux"
 uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-version = "0.8.3"
+version = "0.9.0"
 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
+CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
 CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
 Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
+CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
 DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
 Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -16,7 +18,6 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
-Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
@@ -25,6 +26,8 @@ ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 [compat]
+CUDAapi = "1.1"
+CuArrays = "1.2"
 NNlib = "0.6"
 Zygote = "0.3"
 julia = "1.1"


@@ -14,8 +14,8 @@ Which means allocations occur much faster.
 And you use less memory.
-## Make sure your custom activation functions preserve the type of their inputs
+## Make sure your activation and loss functions preserve the type of their inputs
-Not only should your activation functions be [type-stable](https://docs.julialang.org/en/v1/manual/performance-tips/#Write-%22type-stable%22-functions-1),
+Not only should your activation and loss functions be [type-stable](https://docs.julialang.org/en/v1/manual/performance-tips/#Write-%22type-stable%22-functions-1),
 they should also preserve the type of their inputs.
 A very artificial example using an activation function like
@@ -26,6 +26,7 @@ A very artificial example using an activation function like
 will result in performance on `Float32` input orders of magnitude slower than the normal `tanh` would,
 because it results in having to use slow mixed type multiplication in the dense layers.
+Similar situations can occur in the loss function during backpropagation.
 Which means if you change your data say from `Float64` to `Float32` (which should give a speedup: see above),
 you will see a large slow-down
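To make that concrete, here is a minimal sketch in plain Julia (not taken from the docs; `bad_tanh` and `good_tanh` are made-up names) of an activation that promotes `Float32` inputs to `Float64` versus one that preserves the input type:

```julia
# A Float64 literal inside the activation promotes Float32 inputs to Float64,
# which forces slow mixed-type multiplications in the layers that follow.
bad_tanh(x)  = tanh(x) * 1.0      # 1.0 is a Float64 literal
good_tanh(x) = tanh(x) * one(x)   # one(x) matches the input's element type

x32 = rand(Float32, 1000)
eltype(bad_tanh.(x32))    # Float64 -- type not preserved
eltype(good_tanh.(x32))   # Float32 -- type preserved
```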
@@ -60,7 +61,7 @@ end
 It is much faster to concatenate them into a matrix,
 as this will hit BLAS matrix-matrix multiplication, which is much faster than the equivalent sequence of matrix-vector multiplications.
-Even though this means allocating new memory to store them contiguously.
+The improvement is enough that it is worthwhile allocating new memory to store them contiguously.
 ```julia
 x_batch = reduce(hcat, xs)
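# Illustration only, not part of the diff: a throwaway model to show the batching
# pattern described above. Only `reduce(hcat, xs)` comes from the documentation change.
using Flux
model = Chain(Dense(100, 50, relu), Dense(50, 10))
xs = [rand(Float32, 100) for _ in 1:64]   # 64 individual samples

ys_slow = [model(x) for x in xs]          # 64 separate matrix-vector products
x_batch = reduce(hcat, xs)                # one contiguous 100x64 matrix
y_batch = model(x_batch)                  # a single matrix-matrix product per layer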


@@ -3,7 +3,7 @@ module Flux
 # Zero Flux Given
 using Base: tail
-using Zygote, MacroTools, Juno, Requires, Reexport, Statistics, Random
+using Zygote, MacroTools, Juno, Reexport, Statistics, Random
 using MacroTools: @forward
 @reexport using NNlib
 using Zygote: Params, @adjoint, gradient, forward
@@ -18,7 +18,20 @@ using .Optimise
 using .Optimise: @epochs
 export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
-  ADAMW, InvDecay, ExpDecay, WeightDecay
+  ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay
+
+using CUDAapi
+if has_cuda()
+  try
+    using CuArrays
+    @eval has_cuarrays() = true
+  catch ex
+    @warn "CUDA is installed, but CuArrays.jl fails to load" exception=(ex,catch_backtrace())
+    @eval has_cuarrays() = false
+  end
+else
+  has_cuarrays() = false
+end
 include("utils.jl")
 include("onehot.jl")
@@ -34,6 +47,8 @@ include("data/Data.jl")
 include("deprecations.jl")
-@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" include("cuda/cuda.jl")
+if has_cuarrays()
+  include("cuda/cuda.jl")
+end
 end # module


@@ -1,38 +1,12 @@
 module CUDA
 using ..CuArrays
-import ..CuArrays.CUDAdrv: CuPtr, CU_NULL
-using Pkg.TOML
-function version_check()
-  major_version = 1
-  project = joinpath(dirname(pathof(CuArrays)), "../Project.toml")
-  project = TOML.parse(String(read(project)))
-  version = VersionNumber(get(project, "version", "0.0.0"))
-  if version.major != major_version
-    @warn """
-    Flux is only supported with CuArrays v$major_version.x.
-    Try running `] pin CuArrays@$major_version`.
-    """
-  end
-end
-version_check()
-if !applicable(CuArray{UInt8}, undef, 1)
-  (T::Type{<:CuArray})(::UndefInitializer, sz...) = T(sz...)
-end
-if CuArrays.libcudnn != nothing
-  if isdefined(CuArrays, :libcudnn_handle)
-    handle() = CuArrays.libcudnn_handle[]
-  else
-    handle() = CuArrays.CUDNN.handle()
-  end
+if CuArrays.libcudnn !== nothing # TODO: use CuArrays.has_cudnn()
   include("curnn.jl")
   include("cudnn.jl")
 else
-  @warn("CUDNN is not installed, some functionality will not be available.")
+  @warn "CUDNN is not installed, some functionality will not be available."
 end
 end


@@ -1,5 +1,8 @@
-using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
+using CuArrays: libcudnn
+using CuArrays.CUDNN: @check, handle, cudnnStatus_t, cudnnTensorDescriptor_t,
   cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc
+import CuArrays.CUDAdrv: CuPtr, CU_NULL
 using LinearAlgebra
 mutable struct DropoutDesc


@@ -1,5 +1,9 @@
-using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
+using CuArrays: libcudnn
+using CuArrays.CUDNN: @check, cudnnStatus_t, cudnnTensorDescriptor_t,
   cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc
+import CuArrays.CUDAdrv: CuPtr, CU_NULL
 using LinearAlgebra
 const RNN_RELU = 0 # Stock RNN with ReLu activation
@@ -63,7 +67,7 @@ function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
   @check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Cint,Ptr{Nothing},Cint,Cint,Cint,Cint,Cint),
               handle(),d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))
-  w = cuzeros(T, rnnParamSize(T, d[], input))
+  w = CuArrays.zeros(T, rnnParamSize(T, d[], input))
   # TODO: avoid reserve allocation here
   rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
   finalizer(rd) do x
@@ -130,8 +134,8 @@ end
 # TODO: can we just manipulate strides here?
 # TODO: should use repmat, but this isn't implemented.
 hBatch(x::AbstractVector, h::CuVector) = h
-hBatch(x::AbstractMatrix, h::CuVector) = h .* cuones(1, size(x, 2))
-hBatch(x::AbstractMatrix, h::CuMatrix) = h .* cuones(1, size(h,2) == 1 ? size(x,2) : 1)
+hBatch(x::AbstractMatrix, h::CuVector) = h .* CuArrays.ones(1, size(x, 2))
+hBatch(x::AbstractMatrix, h::CuMatrix) = h .* CuArrays.ones(1, size(h,2) == 1 ? size(x,2) : 1)
 function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T
   h = hBatch(x, h_)
@@ -221,8 +225,8 @@ end
 # Interface
 import ..Flux: Flux, relu
-using .CuArrays.CUDAnative
-using .CuArrays: @cuindex, cudims
+using CuArrays.CUDAnative
+using CuArrays: @cuindex, cudims
 function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
   function kernel(dst, src)


@@ -110,7 +110,7 @@ end
 (a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
   invoke(a, Tuple{AbstractArray}, x)
-(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+(a::Dense{<:Any,W})(x::AbstractArray{<:AbstractFloat}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
   a(T.(x))
 """


@@ -37,7 +37,7 @@ import Adapt: adapt, adapt_structure
 adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
-@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
+if has_cuarrays()
   import .CuArrays: CuArray, cudaconvert
   import Base.Broadcast: BroadcastStyle, ArrayStyle
   BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()


@@ -2,7 +2,7 @@ module Optimise
 export train!,
   SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
-  ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
+  ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM,
   InvDecay, ExpDecay, WeightDecay, stop, Optimiser
 include("optimisers.jl")


@@ -23,7 +23,7 @@ function apply!(o::Descent, x, Δ)
 end
 """
-  Momentum(params, η = 0.01; ρ = 0.9)
+  Momentum(η = 0.01; ρ = 0.9)
 Gradient descent with learning rate `η` and momentum `ρ`.
 """
@@ -108,6 +108,36 @@ function apply!(o::ADAM, x, Δ)
   return Δ
 end
+"""
+  RADAM(η = 0.001, β = (0.9, 0.999))
+
+[RADAM](https://arxiv.org/pdf/1908.03265v1.pdf) optimiser (Rectified ADAM).
+"""
+mutable struct RADAM
+  eta::Float64
+  beta::Tuple{Float64,Float64}
+  state::IdDict
+end
+
+RADAM(η = 0.001, β = (0.9, 0.999)) = RADAM(η, β, IdDict())
+
+function apply!(o::RADAM, x, Δ)
+  η, β = o.eta, o.beta
+  ρ∞ = 2/(1-β[2])-1
+  mt, vt, βp, t = get!(o.state, x, (zero(x), zero(x), β, 1))
+  @. mt = β[1] * mt + (1 - β[1]) * Δ
+  @. vt = β[2] * vt + (1 - β[2]) * Δ^2
+  ρ = ρ∞ - 2t*βp[2]/(1-βp[2])
+  if ρ > 4
+    r = sqrt((ρ-4)*(ρ-2)*ρ∞/((ρ∞-4)*(ρ∞-2)*ρ))
+    @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η * r
+  else
+    @. Δ = mt / (1 - βp[1]) * η
+  end
+  o.state[x] = (mt, vt, βp .* β, t+1)
+  return Δ
+end
 """
   AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08)
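For reference, a short usage sketch (not part of the commit; the model, loss, and data below are invented for illustration) showing how the new optimiser drops into the usual training loop:

```julia
using Flux

m = Chain(Dense(10, 5, relu), Dense(5, 2))
loss(x, y) = Flux.mse(m(x), y)
data = [(rand(Float32, 10, 8), rand(Float32, 2, 8))]   # a single dummy batch

opt = RADAM()                      # defaults: η = 0.001, β = (0.9, 0.999)
Flux.train!(loss, Flux.params(m), data, opt)
```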


@@ -31,6 +31,7 @@ end
 function prefor(f, x; seen = IdSet())
   x ∈ seen && return
+  push!(seen, x)
   f(x)
   foreach(x -> prefor(f, x, seen = seen), children(x))
   return
@@ -59,10 +60,10 @@ end
 cpu(m) = mapleaves(x -> adapt(Array, x), m)
-gpu_adaptor = identity
-@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
-  global gpu_adaptor = CuArrays.cu
+const gpu_adaptor = if has_cuarrays()
+  CuArrays.cu
+else
+  identity
 end
 gpu(x) = mapleaves(gpu_adaptor, x)
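For context, a small sketch (not part of the diff) of how the adaptor is exercised through `gpu` and `cpu`; when CuArrays is unavailable, `gpu` is simply the identity:

```julia
using Flux

m  = Dense(10, 5)              # ordinary CPU model
gm = gpu(m)                    # CuArray-backed copy if CuArrays is loaded, otherwise unchanged
x  = gpu(rand(Float32, 10))
gm(x)                          # runs on whichever backend the parameters live on
cpu(gm)                        # map the parameters back to plain Arrays
```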


@@ -6,7 +6,7 @@ using Test
 @testset "Optimise" begin
   w = randn(10, 10)
   @testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
-                       NADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
+                       NADAM(), RADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
                        Momentum()]
     w′ = randn(10, 10)
     loss(x) = Flux.mse(w*x, w′*x)


@@ -26,8 +26,10 @@ include("layers/conv.jl")
 include("gradients.jl")
-if Base.find_package("CuArrays") != nothing
+if isdefined(Flux, :CUDA)
   include("cuda/cuda.jl")
+else
+  @warn "CUDA unavailable, not testing GPU support"
 end
 end


@@ -76,6 +76,16 @@ end
   @test size.(params(m)) == [(5, 10), (5,)]
   m = RNN(10, 5)
   @test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
+
+  # Layer duplicated in same chain, params just once pls.
+  c = Chain(m, m)
+  @test size.(params(c)) == [(5, 10), (5, 5), (5,), (5,)]
+
+  # Self-referential array. Just want params, no stack overflow pls.
+  r = Any[nothing,m]
+  Flux.children(a::Vector{Any}) = Tuple(a)
+  r[1] = r
+  @test size.(params(r)) == [(5, 10), (5, 5), (5,), (5,)]
 end
 @testset "Basic Stacking" begin