diff --git a/.gitignore b/.gitignore index 785b9c4e..9d6de240 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,4 @@ docs/build/ docs/site/ docs/flux.css -demos +deps diff --git a/docs/src/models/basics.md b/docs/src/models/basics.md index 6fbd0792..02225279 100644 --- a/docs/src/models/basics.md +++ b/docs/src/models/basics.md @@ -151,3 +151,13 @@ m = Chain(x -> x^2, x -> x+1) m(5) # => 26 ``` + +## Layer helpers + +Flux provides a set of helpers for custom layers, which you can enable by calling + +```julia +Flux.treelike(Affine) +``` + +This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md). diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index 28e773b7..1fd87d41 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -30,3 +30,13 @@ leakyrelu elu swish ``` + +## Normalisation & Regularisation + +These layers don't affect the structure of the network but may improve training times or reduce overfitting. + +```@docs +Flux.testmode! +Dropout +LayerNorm +``` diff --git a/docs/src/training/optimisers.md b/docs/src/training/optimisers.md index 3af5604b..56f511e4 100644 --- a/docs/src/training/optimisers.md +++ b/docs/src/training/optimisers.md @@ -58,8 +58,5 @@ All optimisers return a function that, when called, will update the parameters p SGD Momentum Nesterov -RMSProp ADAM -ADAGrad -ADADelta ``` diff --git a/src/Flux.jl b/src/Flux.jl index e4f170f2..7671ddd2 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -7,12 +7,12 @@ module Flux using Juno, Requires using Lazy: @forward -export Chain, Dense, RNN, LSTM, +export Chain, Dense, RNN, LSTM, Dropout, LayerNorm, SGD, ADAM, Momentum, Nesterov, param, params, mapleaves using NNlib -export σ, relu, leakyrelu, elu, swish, softmax +export σ, sigmoid, relu, leakyrelu, elu, swish, softmax include("tracker/Tracker.jl") using .Tracker @@ -22,10 +22,15 @@ using .Optimise include("utils.jl") include("onehot.jl") -include("tree.jl") +include("treelike.jl") include("layers/stateless.jl") include("layers/basic.jl") include("layers/recurrent.jl") +include("layers/normalisation.jl") + +include("data/Data.jl") + +include("batches/Batches.jl") end # module diff --git a/src/batches/Batches.jl b/src/batches/Batches.jl new file mode 100644 index 00000000..066f4d1c --- /dev/null +++ b/src/batches/Batches.jl @@ -0,0 +1,7 @@ +module Batches + +import ..Flux + +include("batch.jl") + +end diff --git a/src/batches/batch.jl b/src/batches/batch.jl new file mode 100644 index 00000000..5a2eb82e --- /dev/null +++ b/src/batches/batch.jl @@ -0,0 +1,8 @@ +struct Batch{T,A,M} + data::A + mask::M +end + +Batch{T}(data, mask) where T = Batch{T,typeof(data),typeof(mask)}(data, mask) + +Batch(xs) = Batch{typeof(first(xs))}(Flux.batch(xs),trues(length(xs))) diff --git a/src/data/Data.jl b/src/data/Data.jl new file mode 100644 index 00000000..ffea729c --- /dev/null +++ b/src/data/Data.jl @@ -0,0 +1,14 @@ +module Data + +export CMUDict, cmudict + +deps(path...) = joinpath(@__DIR__, "..", "..", "deps", path...) 
+ +function __init__() + mkpath(deps()) +end + +include("cmudict.jl") +using .CMUDict + +end diff --git a/src/data/cmudict.jl b/src/data/cmudict.jl new file mode 100644 index 00000000..88b9c6c0 --- /dev/null +++ b/src/data/cmudict.jl @@ -0,0 +1,42 @@ +module CMUDict + +export cmudict + +using ..Data: deps + +const version = "0.7b" + +function load() + isdir(deps("cmudict")) && return + mkpath(deps("cmudict")) + for x in ["", ".phones", ".symbols"] + download("http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x", + deps("cmudict", "cmudict$x")) + end +end + +function phones() + load() + Symbol.(first.(split.(split(readstring(deps("cmudict", "cmudict.phones")), + "\n", keep = false), "\t"))) +end + +function symbols() + load() + Symbol.(split(readstring(deps("CMUDict", "cmudict.symbols")), + "\n", keep = false)) +end + +function rawdict() + load() + Dict(String(xs[1]) => Symbol.(xs[2:end]) for xs in + filter(!isempty, split.(split(readstring(deps("CMUDict", "cmudict")), "\n")))) +end + +validword(s) = ismatch(r"^[\w-\.]+$", s) + +cmudict() = filter((s, ps) -> validword(s), rawdict()) + +alphabet() = ['A':'Z'..., '0':'9'..., '_', '-', '.'] + +end diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 9c8b1016..aa101c43 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -27,7 +27,7 @@ end children(c::Chain) = c.layers mapchildren(f, c::Chain) = Chain(f.(c.layers)...) -(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers) +(c::Chain)(x) = foldl((x, m) -> m(x), x, c.layers) Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...) @@ -78,3 +78,32 @@ function Base.show(io::IO, l::Dense) l.σ == identity || print(io, ", ", l.σ) print(io, ")") end + +""" + Diagonal(in::Integer) + +Creates an element-wise linear transformation layer with learnable +vectors `α` and `β`: + + y = α .* x .+ β + +The input `x` must be a array where `size(x, 1) == in`. +""" +struct Diagonal{T} + α::T + β::T +end + +Diagonal(in::Integer; initα = ones, initβ = zeros) = + Diagonal(param(initα(in)), param(initβ(in))) + +treelike(Diagonal) + +function (a::Diagonal)(x) + α, β = a.α, a.β + α.*x .+ β +end + +function Base.show(io::IO, l::Diagonal) + print(io, "Diagonal(", length(l.α), ")") +end diff --git a/src/layers/normalisation.jl b/src/layers/normalisation.jl new file mode 100644 index 00000000..d296b0a3 --- /dev/null +++ b/src/layers/normalisation.jl @@ -0,0 +1,67 @@ +""" + testmode!(m) + testmode!(m, false) + +Put layers like [`Dropout`](@ref) and `BatchNorm` into testing mode (or back to +training mode with `false`). +""" +function testmode!(m, val::Bool=true) + prefor(x -> _testmode!(x, val), m) + return m +end + +_testmode!(m, test) = nothing + +""" + Dropout(p) + +A Dropout layer. For each input, either sets that input to `0` (with probability +`p`) or scales it by `1/(1-p)`. This is used as a regularisation, i.e. it +reduces overfitting during training. + +Does nothing to the input once in [`testmode!`](@ref). +""" +mutable struct Dropout{F} + p::F + active::Bool +end + +function Dropout(p) + @assert 0 ≤ p ≤ 1 + Dropout{typeof(p)}(p, true) +end + +function (a::Dropout)(x) + a.active || return x + y = similar(x) + rand!(y) + q = 1 - a.p + @inbounds for i=1:length(y) + y[i] = y[i] > a.p ? 1 / q : 0 + end + return y .* x +end + +_testmode!(a::Dropout, test) = (a.active = !test) + +""" + LayerNorm(h::Integer) + +A [normalisation layer](https://arxiv.org/pdf/1607.06450.pdf) designed to be +used with recurrent hidden states of size `h`. 
Normalises the mean/stddev of +each input before applying a per-neuron gain/bias. +""" +struct LayerNorm{T} + diag::Diagonal{T} +end + +LayerNorm(h::Integer) = + LayerNorm(Diagonal(h)) + +treelike(LayerNorm) + +(a::LayerNorm)(x) = a.diag(normalise(x)) + +function Base.show(io::IO, l::LayerNorm) + print(io, "LayerNorm(", length(l.diag.α), ")") +end diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 716bc574..599776ce 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -1,5 +1,7 @@ # TODO: broadcasting cat -combine(x, h) = vcat(x, h .* trues(1, size(x, 2))) +combine(x::AbstractMatrix, h::AbstractVector) = vcat(x, h .* trues(1, size(x, 2))) +combine(x::AbstractVector, h::AbstractVector) = vcat(x, h) +combine(x::AbstractMatrix, h::AbstractMatrix) = vcat(x, h) # Stateful recurrence diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 3931c216..edbdec58 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -1,14 +1,27 @@ +using NNlib: log_fast + # Cost functions mse(ŷ, y) = sum((ŷ .- y).^2)/length(y) crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat) = - -sum(y .* log.(ŷ)) / size(y, 2) + -sum(y .* log_fast.(ŷ)) / size(y, 2) @deprecate logloss(x, y) crossentropy(x, y) function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat) logŷ = logŷ .- maximum(logŷ, 1) - ypred = logŷ .- log.(sum(exp.(logŷ), 1)) + ypred = logŷ .- log_fast.(sum(exp.(logŷ), 1)) -sum(y .* ypred) / size(y, 2) end + +""" + normalise(x::AbstractVecOrMat) + +Normalise each column of `x` to mean 0 and standard deviation 1. +""" +function normalise(x::AbstractVecOrMat) + μ′ = mean(x, 1) + σ′ = std(x, 1, mean = μ′) + return (x .- μ′) ./ σ′ +end diff --git a/src/onehot.jl b/src/onehot.jl index 5414773c..f94fb93e 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -1,3 +1,5 @@ +import Base: * + struct OneHotVector <: AbstractVector{Bool} ix::UInt32 of::UInt32 @@ -7,7 +9,7 @@ Base.size(xs::OneHotVector) = (Int64(xs.of),) Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix -Base.:*(A::AbstractMatrix, b::OneHotVector) = A[:, b.ix] +A::AbstractMatrix * b::OneHotVector = A[:, b.ix] struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool} height::Int @@ -18,7 +20,7 @@ Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data)) Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i] -Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)] +A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)] Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...]) @@ -40,10 +42,22 @@ function onehot(l, labels) OneHotVector(i, length(labels)) end -onehotbatch(ls, labels) = OneHotMatrix(length(labels), [onehot(l, labels) for l in ls]) +function onehot(l, labels, unk) + i = findfirst(labels, l) + i > 0 || return onehot(unk, labels) + OneHotVector(i, length(labels)) +end + +onehotbatch(ls, labels, unk...) = + OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls]) argmax(y::AbstractVector, labels = 1:length(y)) = labels[findfirst(y, maximum(y))] argmax(y::AbstractMatrix, l...) 
= squeeze(mapslices(y -> argmax(y, l...), y, 1), 1) + +# Ambiguity hack + +a::TrackedMatrix * b::OneHotVector = TrackedArray(Tracker.Call(*, a, b)) +a::TrackedMatrix * b::OneHotMatrix = TrackedArray(Tracker.Call(*, a, b)) diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index 95b31b98..abc54090 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -38,7 +38,7 @@ function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8) acc = zeros(p.x) .+ ϵ function () @. acc = ρ * acc + (1 - ρ) * p.Δ ^ 2 - @. p.Δ /= √acc * η + @. p.Δ *= η / √acc end end @@ -46,7 +46,7 @@ function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8) acc = zeros(p.x) .+ ϵ function () @. acc += p.Δ ^ 2 - @. p.Δ /= √acc * η + @. p.Δ *= η / √acc end end diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 2a2ec5eb..0809e86b 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -1,8 +1,8 @@ using Juno using Flux.Tracker: back! -tocb(f) = f -tocb(fs::AbstractVector) = () -> foreach(call, fs) +runall(f) = f +runall(fs::AbstractVector) = () -> foreach(call, fs) """ train!(loss, data, opt; cb = () -> ()) @@ -11,10 +11,11 @@ For each datapoint `d` in `data` computes the gradient of `loss(d...)` through backpropagation and calls the optimizer `opt` and the callback `cb` (i.e. `opt()` and `cb()`). -Multiple callbacks can be passed to `cb` as an array. +Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays. """ function train!(loss, data, opt; cb = () -> ()) - cb = tocb(cb) + cb = runall(cb) + opt = runall(opt) @progress for d in data l = loss(d...) isinf(l.data[]) && error("Loss is Inf") diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 8e6a584a..74ed2d75 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -1,6 +1,6 @@ module Tracker -export TrackedArray, param, back! +export TrackedArray, TrackedVector, TrackedMatrix, param, back! data(x) = x istracked(x) = false @@ -38,7 +38,9 @@ TrackedArray(c::Call) = TrackedArray(c, c()) TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x)) -param(xs) = TrackedArray(AbstractFloat.(xs)) +isleaf(x::TrackedArray) = x.f == Call(nothing) + +param(xs) = TrackedArray(map(x -> AbstractFloat(x), xs)) param(xs::Real) = param(fill(xs)) istracked(x::TrackedArray) = true @@ -56,6 +58,18 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) = Base.similar(x::TrackedArray, T::Type) = similar(data(x), T) +value(x) = x +value(x::TrackedArray) = data(x) +value(x::TrackedScalar) = data(x)[] + +Base.:(==)(x::TrackedArray, y) = value(x) == y +Base.:(==)(y, x::TrackedArray) = y == value(x) +Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(x) + +Base.isless(x::TrackedScalar, y) = isless(value(x), y) +Base.isless(x, y::TrackedScalar) = isless(x, value(y)) +Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(value(x), value(y)) + Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} = print(io, "TrackedArray{…,$A}") @@ -70,6 +84,9 @@ function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = tru end end +Base.setindex!(xs::TrackedArray, v, i...) 
= + error("Can't differentiate `setindex!`") + include("back.jl") include("lib.jl") include("numeric.jl") diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index 95bfc1d3..8ac386d8 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -1,5 +1,3 @@ -import Base: * - toarray(xs::AbstractArray, ys::AbstractArray) = ys toarray(xs::AbstractArray, y) = similar(xs, typeof(y), ()) .= y @@ -59,21 +57,58 @@ back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...) Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...) +Base.mean(xs::TrackedArray) = TrackedArray(Call(mean, xs), toarray(xs.data, mean(xs.data))) +Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region)) + +# Hacks to get std working +Base.std(x::TrackedArray; mean = Base.mean(x)) = + sqrt.(sum((x .- mean).^2) ./ (length(x)-1)) +Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) = + sqrt.(sum((x .- mean).^2, dim) ./ (size(x, dim)-1)) + +back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./ length(xs.data)) +back(::typeof(mean), Δ, xs::TrackedArray, region) = + back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...))) + # BLAS -a::TrackedMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b)) -a::TrackedMatrix * b::AbstractMatrix = TrackedArray(Call(*, a, b)) -a::AbstractMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b)) +for f in :[*, Ac_mul_B].args + @eval begin + import Base.$f + $f(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b)) + $f(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call($f, a, b)) + $f(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b)) -a::TrackedMatrix * b::TrackedVector = TrackedArray(Call(*, a, b)) -a::TrackedMatrix * b::AbstractVector = TrackedArray(Call(*, a, b)) -a::AbstractMatrix * b::TrackedVector = TrackedArray(Call(*, a, b)) + $f(a::TrackedMatrix, b::TrackedVector) = TrackedArray(Call($f, a, b)) + $f(a::TrackedMatrix, b::AbstractVector) = TrackedArray(Call($f, a, b)) + $f(a::AbstractMatrix, b::TrackedVector) = TrackedArray(Call($f, a, b)) + + $f(a::TrackedVector, b::TrackedVector) = TrackedArray(Call($f, a, b)) + $f(a::TrackedVector, b::AbstractVector) = TrackedArray(Call($f, a, b)) + $f(a::AbstractVector, b::TrackedVector) = TrackedArray(Call($f, a, b)) + end +end function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat) @back(a, A_mul_Bt(Δ, data(b))) @back(b, At_mul_B(data(a), Δ)) end +function back(::typeof(Ac_mul_B), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real}) + @back(a, A_mul_Bt(Δ, data(b))') + @back(b, *(data(a), Δ)) +end + +# Fast path for matrix-vector +function back(::typeof(*), Δ::AbstractVector, W::TrackedMatrix, x::AbstractVector) + if isleaf(W) + W.grad .+= Δ .* data(x).' + else + back(W, A_mul_Bt(Δ, data(x))) + end + @back(x, At_mul_B(data(W), Δ)) +end + # NNlib import NNlib: softmax, ∇softmax diff --git a/src/tree.jl b/src/treelike.jl similarity index 89% rename from src/tree.jl rename to src/treelike.jl index efdf9101..097ccdc6 100644 --- a/src/tree.jl +++ b/src/treelike.jl @@ -1,6 +1,9 @@ children(x) = () mapchildren(f, x) = x +children(x::Tuple) = x +mapchildren(f, x::Tuple) = map(f, x) + function treelike(T, fs = fieldnames(T)) @eval begin children(x::$T) = ($([:(x.$f) for f in fs]...),) @@ -32,3 +35,5 @@ function params(m) prefor(p -> p isa TrackedArray && push!(ps, p), m) return ps end + +params(m...) 
= params(m) diff --git a/test/data.jl b/test/data.jl new file mode 100644 index 00000000..1b93ab3c --- /dev/null +++ b/test/data.jl @@ -0,0 +1,3 @@ +using Flux.Data + +@test cmudict()["CATASTROPHE"] == :[K,AH0,T,AE1,S,T,R,AH0,F,IY0].args diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl new file mode 100644 index 00000000..5a302a51 --- /dev/null +++ b/test/layers/normalisation.jl @@ -0,0 +1,28 @@ +using Flux: testmode! + +@testset "Dropout" begin + x = [1.,2.,3.] + @test x == testmode!(Dropout(0.1))(x) + @test x == Dropout(0)(x) + @test zeros(x) == Dropout(1)(x) + + x = rand(100) + m = Dropout(0.9) + y = m(x) + @test count(a->a==0, y) > 50 + testmode!(m) + y = m(x) + @test count(a->a==0, y) == 0 + testmode!(m, false) + y = m(x) + @test count(a->a==0, y) > 50 + + x = rand(100) + m = Chain(Dense(100,100), + Dropout(0.9)) + y = m(x) + @test count(a->a == 0, y) > 50 + testmode!(m) + y = m(x) + @test count(a->a == 0, y) == 0 +end diff --git a/test/runtests.jl b/test/runtests.jl index 2ab0e447..efd1a462 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,5 +4,6 @@ using Flux, Base.Test include("utils.jl") include("tracker.jl") +include("layers/normalisation.jl") end diff --git a/test/tracker.jl b/test/tracker.jl index 2a20338e..81a72566 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -9,6 +9,8 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2) @test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2) +@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10)) + @test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5)) @test gradtest(x -> softmax(x).*(1:3), 3) @@ -22,23 +24,22 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(vcat, rand(5), rand(3)) @test gradtest(vcat, rand(2,3), rand(3,3)) +@testset "mean" begin + @test gradtest(mean, rand(2, 3)) + + @test gradtest(x -> mean(x, 1), rand(2, 3)) + @test gradtest(x -> mean(x, 2), rand(2, 3)) + @test gradtest(x -> mean(x, 3), rand(2, 3, 4)) + + @test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4)) +end + +@test gradtest(x -> std(x), rand(5,5)) +@test gradtest(x -> std(x, 1), rand(5,5)) + @test gradtest(rand(5)) do x y = x.^2 2y + x end -for T in [Float32, Float64] - @test isa(param(T(1)), TrackedArray{T, 0}) - @test isa(param(rand(T, 2)), TrackedArray{T, 1}) - @test isa(param(rand(T, 2,2)), TrackedArray{T, 2}) -end - -# TODO: do we wand this behaviour ?? -F = typeof(AbstractFloat(1)) -for T in [Int32, Int64] - @test isa(param(T(1)), TrackedArray{F, 0}) - @test isa(param(rand(T, 2)), TrackedArray{F, 1}) - @test isa(param(rand(T, 2,2)), TrackedArray{F, 2}) -end - end #testset
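
A minimal usage sketch of the user-facing pieces this patch introduces — `Dropout`/`testmode!`, `LayerNorm`, `Flux.treelike`, and the unknown-label fallback for `onehot`/`onehotbatch` — assuming the Julia 0.6-era Flux API shown in the hunks above. The `Affine` layer is the docs' running example (see the `docs/src/models/basics.md` hunk), re-declared here purely for illustration; its constructor and initialisation are illustrative, not part of the diff.

```julia
using Flux
using Flux: testmode!, onehot, onehotbatch

# Dropout zeroes each input with probability p and rescales the rest by
# 1/(1 - p); testmode! turns it into the identity for evaluation.
m = Chain(Dense(10, 5, relu), Dropout(0.5), Dense(5, 2), softmax)
x = rand(10)
m(x)                 # dropout active (training mode)
testmode!(m)         # reaches the Dropout inside the Chain via prefor
m(x)                 # dropout is now a no-op
testmode!(m, false)  # back to training mode

# LayerNorm normalises each column to mean 0 / std 1, then applies the
# learnable per-neuron gain and bias of the new Diagonal layer.
ln = LayerNorm(5)
ln(rand(5, 3))

# treelike: opt a custom struct into params/mapleaves traversal, as the
# updated docs describe for the Affine example.
struct Affine
  W
  b
end
Affine(in::Integer, out::Integer) = Affine(param(randn(out, in)), param(randn(out)))
(a::Affine)(x) = a.W * x .+ a.b

Flux.treelike(Affine)
a = Affine(10, 5)
params(a)            # now collects a.W and a.b

# onehot/onehotbatch with an unknown-label fallback (src/onehot.jl):
# labels not in the list map to the given fallback instead of erroring.
onehot(:c, [:a, :b], :a)
onehotbatch([:a, :c, :b], [:a, :b], :a)
```

`testmode!` simply flips each layer's `active` flag through `prefor`, so calling it on a `Chain` switches every `Dropout` it contains — the same behaviour exercised by the new `test/layers/normalisation.jl` testset.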