Merge branch 'master' into zygote
This commit is contained in:
commit 67c38b3099
@@ -33,12 +33,35 @@ git-tree-sha1 = "c7361ce8a2129f20b0e05a89f7070820cfed6648"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.6"

[[CEnum]]
git-tree-sha1 = "62847acab40e6855a9b5905ccb99c2b5cf6b3ebb"
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
version = "0.2.0"

[[CSTParser]]
deps = ["Tokenize"]
git-tree-sha1 = "c69698c3d4a7255bc1b4bc2afc09f59db910243b"
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
version = "0.6.2"

[[CUDAapi]]
deps = ["Libdl", "Logging"]
git-tree-sha1 = "9b2b4b71d6b7f946c9689bb4dea03ff92e3c7091"
uuid = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
version = "1.1.0"

[[CUDAdrv]]
deps = ["CUDAapi", "Libdl", "Printf"]
git-tree-sha1 = "9ce99b5732c70e06ed97c042187baed876fb1698"
uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
version = "3.1.0"

[[CUDAnative]]
deps = ["Adapt", "CUDAapi", "CUDAdrv", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Printf", "TimerOutputs"]
git-tree-sha1 = "0a00bef482b7c9127495c7f4a2a85e73b13b5af8"
uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
version = "2.3.0"

[[CodecZlib]]
deps = ["BinaryProvider", "Libdl", "TranscodingStreams"]
git-tree-sha1 = "05916673a2627dd91b4969ff8ba6941bc85a960e"
@@ -81,6 +104,12 @@ git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
version = "4.0.0"

[[CuArrays]]
deps = ["AbstractFFTs", "Adapt", "CUDAapi", "CUDAdrv", "CUDAnative", "GPUArrays", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"]
git-tree-sha1 = "46b48742a84bb839e74215b7e468a4a1c6ba30f9"
uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
version = "1.2.1"

[[DataAPI]]
git-tree-sha1 = "8903f0219d3472543fc4b2f5ebaf675a07f817c0"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
@@ -124,9 +153,9 @@ version = "0.3.0"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "4c707c87ddd3199fc5624d5c98b2c706e4d00675"
git-tree-sha1 = "8fba6ddaf66b45dec830233cea0aae43eb1261ad"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.7.0"
version = "0.6.4"

[[FixedPointNumbers]]
git-tree-sha1 = "d14a6fa5890ea3a7e5dcab6811114f132fec2b4b"
@@ -139,6 +168,12 @@ git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.3"

[[GPUArrays]]
deps = ["Adapt", "FFTW", "FillArrays", "LinearAlgebra", "Printf", "Random", "Serialization", "StaticArrays", "Test"]
git-tree-sha1 = "dd169c636d1d3656a9faca772f5bd7c226a61254"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "1.0.1"

[[IRTools]]
deps = ["InteractiveUtils", "MacroTools", "Test"]
git-tree-sha1 = "e23faa71b8f54c3fdc99b230b9c2906cafdddca5"
@@ -161,6 +196,12 @@ git-tree-sha1 = "30d94657a422d09cb97b6f86f04f750fa9c50df8"
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
version = "0.7.2"

[[LLVM]]
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "52cfea426bd248a427aace7d88eb5d45b84ea297"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "1.2.0"

[[LibGit2]]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
NEWS.md
@@ -1,6 +1,7 @@
# v0.9.0
* [Depthwise convolutional layer API changes](https://github.com/FluxML/Flux.jl/pull/756) from `in => mult` channel specification to `in => out` channel specification, and deprecates implicit `out` constructor.
* New [SkipConnection](https://github.com/FluxML/Flux.jl/pull/446), which can be used to train residual neural network architectures.
* New [RADAM](https://github.com/FluxML/Flux.jl/pull/842) optimiser.

# v0.8.0
@@ -1,12 +1,14 @@
name = "Flux"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.8.3"
version = "0.9.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -16,7 +18,6 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
@@ -25,6 +26,8 @@ ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
CUDAapi = "1.1"
CuArrays = "1.2"
NNlib = "0.6"
Zygote = "0.3"
julia = "1.1"
@@ -14,8 +14,8 @@ Which means allocations occur much faster.
And you use less memory.


## Make sure your custom activation functions preserve the type of their inputs
Not only should your activation functions be [type-stable](https://docs.julialang.org/en/v1/manual/performance-tips/#Write-%22type-stable%22-functions-1),
## Make sure your activation and loss functions preserve the type of their inputs
Not only should your activation and loss functions be [type-stable](https://docs.julialang.org/en/v1/manual/performance-tips/#Write-%22type-stable%22-functions-1),
they should also preserve the type of their inputs.

A very artificial example using an activation function like
@@ -26,6 +26,7 @@ A very artificial example using an activation function like

will result in performance on `Float32` input orders of magnitude slower than the normal `tanh` would,
because it results in having to use slow mixed type multiplication in the dense layers.
Similar situations can occur in the loss function during backpropagation.

Which means if you change your data say from `Float64` to `Float32` (which should give a speedup: see above),
you will see a large slow-down
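To make the warning in this hunk concrete, here is a small sketch (not part of the diff; the function names are made up) contrasting an activation that promotes `Float32` inputs to `Float64` with a type-preserving version using the `oftype` idiom:

```julia
# Hypothetical activations illustrating the type-preservation advice above.
leaky_bad(x)  = 0.01 * x + tanh(x)                 # 0.01 is a Float64 literal, so Float32 inputs get promoted
leaky_good(x) = oftype(x / 1, 0.01) * x + tanh(x)  # keeps the element type of x

x = rand(Float32, 100)
eltype(leaky_bad.(x))   # Float64 -- forces slow mixed-type multiplication in the dense layers
eltype(leaky_good.(x))  # Float32
```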
@@ -60,7 +61,7 @@ end

It is much faster to concatenate them into a matrix,
as this will hit BLAS matrix-matrix multiplication, which is much faster than the equivalent sequence of matrix-vector multiplications.
Even though this means allocating new memory to store them contiguously.
The improvement is enough that it is worthwhile allocating new memory to store them contiguously.

```julia
x_batch = reduce(hcat, xs)
src/Flux.jl
@@ -3,7 +3,7 @@ module Flux
# Zero Flux Given

using Base: tail
using Zygote, MacroTools, Juno, Requires, Reexport, Statistics, Random
using Zygote, MacroTools, Juno, Reexport, Statistics, Random
using MacroTools: @forward
@reexport using NNlib
using Zygote: Params, @adjoint, gradient, forward
@@ -18,7 +18,20 @@ using .Optimise
using .Optimise: @epochs
export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
  ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
  ADAMW, InvDecay, ExpDecay, WeightDecay
  ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay

using CUDAapi
if has_cuda()
  try
    using CuArrays
    @eval has_cuarrays() = true
  catch ex
    @warn "CUDA is installed, but CuArrays.jl fails to load" exception=(ex,catch_backtrace())
    @eval has_cuarrays() = false
  end
else
  has_cuarrays() = false
end

include("utils.jl")
include("onehot.jl")
@@ -34,6 +47,8 @@ include("data/Data.jl")

include("deprecations.jl")

@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" include("cuda/cuda.jl")
if has_cuarrays()
  include("cuda/cuda.jl")
end

end # module
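For illustration only (not part of the diff), downstream code can branch on the detection flag introduced above before moving work to the GPU; `has_cuarrays` is not exported in this hunk, so it is qualified here:

```julia
# Hypothetical usage of the CUDA detection added in src/Flux.jl above.
using Flux

model = Chain(Dense(10, 5, relu), Dense(5, 2))
model = Flux.has_cuarrays() ? gpu(model) : model  # only adapt to CuArrays when they actually loaded
```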
@@ -1,38 +1,12 @@
module CUDA

using ..CuArrays
import ..CuArrays.CUDAdrv: CuPtr, CU_NULL
using Pkg.TOML

function version_check()
  major_version = 1
  project = joinpath(dirname(pathof(CuArrays)), "../Project.toml")
  project = TOML.parse(String(read(project)))
  version = VersionNumber(get(project, "version", "0.0.0"))
  if version.major != major_version
    @warn """
      Flux is only supported with CuArrays v$major_version.x.
      Try running `] pin CuArrays@$major_version`.
      """
  end
end

version_check()

if !applicable(CuArray{UInt8}, undef, 1)
  (T::Type{<:CuArray})(::UndefInitializer, sz...) = T(sz...)
end

if CuArrays.libcudnn != nothing
  if isdefined(CuArrays, :libcudnn_handle)
    handle() = CuArrays.libcudnn_handle[]
  else
    handle() = CuArrays.CUDNN.handle()
  end
if CuArrays.libcudnn !== nothing # TODO: use CuArrays.has_cudnn()
  include("curnn.jl")
  include("cudnn.jl")
else
  @warn("CUDNN is not installed, some functionality will not be available.")
  @warn "CUDNN is not installed, some functionality will not be available."
end

end
@@ -1,5 +1,8 @@
using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
using CuArrays: libcudnn
using CuArrays.CUDNN: @check, handle, cudnnStatus_t, cudnnTensorDescriptor_t,
  cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc
import CuArrays.CUDAdrv: CuPtr, CU_NULL

using LinearAlgebra

mutable struct DropoutDesc
@@ -1,5 +1,9 @@
using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
using CuArrays: libcudnn
using CuArrays.CUDNN: @check, cudnnStatus_t, cudnnTensorDescriptor_t,
  cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc

import CuArrays.CUDAdrv: CuPtr, CU_NULL

using LinearAlgebra

const RNN_RELU = 0 # Stock RNN with ReLu activation
@@ -63,7 +67,7 @@ function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
  @check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Cint,Ptr{Nothing},Cint,Cint,Cint,Cint,Cint),
               handle(),d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))

  w = cuzeros(T, rnnParamSize(T, d[], input))
  w = CuArrays.zeros(T, rnnParamSize(T, d[], input))
  # TODO: avoid reserve allocation here
  rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
  finalizer(rd) do x
@@ -130,8 +134,8 @@ end
# TODO: can we just manipulate strides here?
# TODO: should use repmat, but this isn't implemented.
hBatch(x::AbstractVector, h::CuVector) = h
hBatch(x::AbstractMatrix, h::CuVector) = h .* cuones(1, size(x, 2))
hBatch(x::AbstractMatrix, h::CuMatrix) = h .* cuones(1, size(h,2) == 1 ? size(x,2) : 1)
hBatch(x::AbstractMatrix, h::CuVector) = h .* CuArrays.ones(1, size(x, 2))
hBatch(x::AbstractMatrix, h::CuMatrix) = h .* CuArrays.ones(1, size(h,2) == 1 ? size(x,2) : 1)

function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T
  h = hBatch(x, h_)
@@ -221,8 +225,8 @@ end
# Interface

import ..Flux: Flux, relu
using .CuArrays.CUDAnative
using .CuArrays: @cuindex, cudims
using CuArrays.CUDAnative
using CuArrays: @cuindex, cudims

function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
  function kernel(dst, src)
@@ -110,7 +110,7 @@ end
(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
  invoke(a, Tuple{AbstractArray}, x)

(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
(a::Dense{<:Any,W})(x::AbstractArray{<:AbstractFloat}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
  a(T.(x))

"""
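A minimal sketch (not from the diff) of what the tightened signature means in practice, assuming a layer built with `Float32` parameters: floating-point inputs of another precision are still converted to the weight eltype, while integer inputs now fall through to the generic `W*x .+ b` method instead of being converted first:

```julia
using Flux

# Dense layer with explicit Float32 weights and bias, so the example does not
# depend on Flux's default initialisers.
d = Dense(rand(Float32, 2, 3), zeros(Float32, 2), identity)

eltype(d(rand(Float64, 3)))  # Float32: matches AbstractArray{<:AbstractFloat}, input is converted
d([1, 2, 3])                 # Int input no longer matches this method; it takes the generic path
```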
@@ -37,7 +37,7 @@ import Adapt: adapt, adapt_structure

adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))

@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
if has_cuarrays()
  import .CuArrays: CuArray, cudaconvert
  import Base.Broadcast: BroadcastStyle, ArrayStyle
  BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
@@ -2,7 +2,7 @@ module Optimise

export train!,
  SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
  ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
  ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM,
  InvDecay, ExpDecay, WeightDecay, stop, Optimiser

include("optimisers.jl")
@@ -23,7 +23,7 @@ function apply!(o::Descent, x, Δ)
end

"""
    Momentum(params, η = 0.01; ρ = 0.9)
    Momentum(η = 0.01; ρ = 0.9)

Gradient descent with learning rate `η` and momentum `ρ`.
"""
@@ -108,6 +108,36 @@ function apply!(o::ADAM, x, Δ)
  return Δ
end

"""
    RADAM(η = 0.001, β = (0.9, 0.999))

[RADAM](https://arxiv.org/pdf/1908.03265v1.pdf) optimiser (Rectified ADAM).
"""
mutable struct RADAM
  eta::Float64
  beta::Tuple{Float64,Float64}
  state::IdDict
end

RADAM(η = 0.001, β = (0.9, 0.999)) = RADAM(η, β, IdDict())

function apply!(o::RADAM, x, Δ)
  η, β = o.eta, o.beta
  ρ∞ = 2/(1-β[2])-1
  mt, vt, βp, t = get!(o.state, x, (zero(x), zero(x), β, 1))
  @. mt = β[1] * mt + (1 - β[1]) * Δ
  @. vt = β[2] * vt + (1 - β[2]) * Δ^2
  ρ = ρ∞ - 2t*βp[2]/(1-βp[2])
  if ρ > 4
    r = sqrt((ρ-4)*(ρ-2)*ρ∞/((ρ∞-4)*(ρ∞-2)*ρ))
    @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η * r
  else
    @. Δ = mt / (1 - βp[1]) * η
  end
  o.state[x] = (mt, vt, βp .* β, t+1)
  return Δ
end

"""
    AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08)
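As a usage sketch (assuming the Flux API of this branch; the model and data below are made up for illustration), the new optimiser drops straight into the usual `train!` loop:

```julia
using Flux

# Toy regression data, purely illustrative.
W_true = rand(Float32, 2, 10)
data = [(x = rand(Float32, 10); (x, W_true * x)) for _ in 1:100]

model = Dense(10, 2)
loss(x, y) = Flux.mse(model(x), y)

opt = RADAM(0.001, (0.9, 0.999))  # the defaults, written out explicitly
Flux.train!(loss, Flux.params(model), data, opt)
```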
@@ -31,6 +31,7 @@ end

function prefor(f, x; seen = IdSet())
  x ∈ seen && return
  push!(seen, x)
  f(x)
  foreach(x -> prefor(f, x, seen = seen), children(x))
  return
@@ -59,10 +60,10 @@ end

cpu(m) = mapleaves(x -> adapt(Array, x), m)

gpu_adaptor = identity

@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
  global gpu_adaptor = CuArrays.cu
const gpu_adaptor = if has_cuarrays()
  CuArrays.cu
else
  identity
end

gpu(x) = mapleaves(gpu_adaptor, x)
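For illustration (not part of the diff), the `gpu_adaptor` selected above is what `gpu` applies to every parameter array, so on a machine where CuArrays did not load `gpu` is effectively a no-op:

```julia
using Flux

m  = Chain(Dense(10, 5, relu), Dense(5, 2))
gm = gpu(m)   # parameters become CuArrays when available, otherwise returned unchanged
cm = cpu(gm)  # always maps parameters back to plain Arrays
```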
@@ -6,7 +6,7 @@ using Test
@testset "Optimise" begin
  w = randn(10, 10)
  @testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
                       NADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
                       NADAM(), RADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
                       Momentum()]
    w′ = randn(10, 10)
    loss(x) = Flux.mse(w*x, w′*x)
@@ -26,8 +26,10 @@ include("layers/conv.jl")

include("gradients.jl")

if Base.find_package("CuArrays") != nothing
if isdefined(Flux, :CUDA)
  include("cuda/cuda.jl")
else
  @warn "CUDA unavailable, not testing GPU support"
end

end
@@ -76,6 +76,16 @@ end
  @test size.(params(m)) == [(5, 10), (5,)]
  m = RNN(10, 5)
  @test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]

  # Layer duplicated in same chain, params just once pls.
  c = Chain(m, m)
  @test size.(params(c)) == [(5, 10), (5, 5), (5,), (5,)]

  # Self-referential array. Just want params, no stack overflow pls.
  r = Any[nothing,m]
  Flux.children(a::Vector{Any}) = Tuple(a)
  r[1] = r
  @test size.(params(r)) == [(5, 10), (5, 5), (5,), (5,)]
end

@testset "Basic Stacking" begin