Merge branch 'master' into cat-fix

This commit is contained in:
GenaBitu 2017-12-05 11:13:29 +01:00
commit 62b3600eca
No known key found for this signature in database
GPG Key ID: 6E647E317A9DD426
23 changed files with 353 additions and 44 deletions

2
.gitignore vendored
View File

@ -4,4 +4,4 @@
docs/build/ docs/build/
docs/site/ docs/site/
docs/flux.css docs/flux.css
demos deps

View File

@ -151,3 +151,13 @@ m = Chain(x -> x^2, x -> x+1)
m(5) # => 26 m(5) # => 26
``` ```
## Layer helpers
Flux provides a set of helpers for custom layers, which you can enable by calling
```julia
Flux.treelike(Affine)
```
This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).

View File

@ -30,3 +30,13 @@ leakyrelu
elu elu
swish swish
``` ```
## Normalisation & Regularisation
These layers don't affect the structure of the network but may improve training times or reduce overfitting.
```@docs
Flux.testmode!
Dropout
LayerNorm
```

View File

@ -58,8 +58,5 @@ All optimisers return a function that, when called, will update the parameters p
SGD SGD
Momentum Momentum
Nesterov Nesterov
RMSProp
ADAM ADAM
ADAGrad
ADADelta
``` ```

View File

@ -7,12 +7,12 @@ module Flux
using Juno, Requires using Juno, Requires
using Lazy: @forward using Lazy: @forward
export Chain, Dense, RNN, LSTM, export Chain, Dense, RNN, LSTM, Dropout, LayerNorm,
SGD, ADAM, Momentum, Nesterov, SGD, ADAM, Momentum, Nesterov,
param, params, mapleaves param, params, mapleaves
using NNlib using NNlib
export σ, relu, leakyrelu, elu, swish, softmax export σ, sigmoid, relu, leakyrelu, elu, swish, softmax
include("tracker/Tracker.jl") include("tracker/Tracker.jl")
using .Tracker using .Tracker
@ -22,10 +22,15 @@ using .Optimise
include("utils.jl") include("utils.jl")
include("onehot.jl") include("onehot.jl")
include("tree.jl") include("treelike.jl")
include("layers/stateless.jl") include("layers/stateless.jl")
include("layers/basic.jl") include("layers/basic.jl")
include("layers/recurrent.jl") include("layers/recurrent.jl")
include("layers/normalisation.jl")
include("data/Data.jl")
include("batches/Batches.jl")
end # module end # module

7
src/batches/Batches.jl Normal file
View File

@ -0,0 +1,7 @@
module Batches
import ..Flux
include("batch.jl")
end

8
src/batches/batch.jl Normal file
View File

@ -0,0 +1,8 @@
struct Batch{T,A,M}
data::A
mask::M
end
Batch{T}(data, mask) where T = Batch{T,typeof(data),typeof(mask)}(data, mask)
Batch(xs) = Batch{typeof(first(xs))}(Flux.batch(xs),trues(length(xs)))

14
src/data/Data.jl Normal file
View File

@ -0,0 +1,14 @@
module Data
export CMUDict, cmudict
deps(path...) = joinpath(@__DIR__, "..", "..", "deps", path...)
function __init__()
mkpath(deps())
end
include("cmudict.jl")
using .CMUDict
end

42
src/data/cmudict.jl Normal file
View File

@ -0,0 +1,42 @@
module CMUDict
export cmudict
using ..Data: deps
const version = "0.7b"
function load()
isdir(deps("cmudict")) && return
mkpath(deps("cmudict"))
for x in ["", ".phones", ".symbols"]
download("http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-$version$x",
deps("cmudict", "cmudict$x"))
end
end
function phones()
load()
Symbol.(first.(split.(split(readstring(deps("cmudict", "cmudict.phones")),
"\n", keep = false), "\t")))
end
function symbols()
load()
Symbol.(split(readstring(deps("CMUDict", "cmudict.symbols")),
"\n", keep = false))
end
function rawdict()
load()
Dict(String(xs[1]) => Symbol.(xs[2:end]) for xs in
filter(!isempty, split.(split(readstring(deps("CMUDict", "cmudict")), "\n"))))
end
validword(s) = ismatch(r"^[\w-\.]+$", s)
cmudict() = filter((s, ps) -> validword(s), rawdict())
alphabet() = ['A':'Z'..., '0':'9'..., '_', '-', '.']
end

View File

@ -27,7 +27,7 @@ end
children(c::Chain) = c.layers children(c::Chain) = c.layers
mapchildren(f, c::Chain) = Chain(f.(c.layers)...) mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers) (c::Chain)(x) = foldl((x, m) -> m(x), x, c.layers)
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...) Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
@ -78,3 +78,32 @@ function Base.show(io::IO, l::Dense)
l.σ == identity || print(io, ", ", l.σ) l.σ == identity || print(io, ", ", l.σ)
print(io, ")") print(io, ")")
end end
"""
Diagonal(in::Integer)
Creates an element-wise linear transformation layer with learnable
vectors `α` and `β`:
y = α .* x .+ β
The input `x` must be a array where `size(x, 1) == in`.
"""
struct Diagonal{T}
α::T
β::T
end
Diagonal(in::Integer; initα = ones, initβ = zeros) =
Diagonal(param(initα(in)), param(initβ(in)))
treelike(Diagonal)
function (a::Diagonal)(x)
α, β = a.α, a.β
α.*x .+ β
end
function Base.show(io::IO, l::Diagonal)
print(io, "Diagonal(", length(l.α), ")")
end

View File

@ -0,0 +1,67 @@
"""
testmode!(m)
testmode!(m, false)
Put layers like [`Dropout`](@ref) and `BatchNorm` into testing mode (or back to
training mode with `false`).
"""
function testmode!(m, val::Bool=true)
prefor(x -> _testmode!(x, val), m)
return m
end
_testmode!(m, test) = nothing
"""
Dropout(p)
A Dropout layer. For each input, either sets that input to `0` (with probability
`p`) or scales it by `1/(1-p)`. This is used as a regularisation, i.e. it
reduces overfitting during training.
Does nothing to the input once in [`testmode!`](@ref).
"""
mutable struct Dropout{F}
p::F
active::Bool
end
function Dropout(p)
@assert 0 p 1
Dropout{typeof(p)}(p, true)
end
function (a::Dropout)(x)
a.active || return x
y = similar(x)
rand!(y)
q = 1 - a.p
@inbounds for i=1:length(y)
y[i] = y[i] > a.p ? 1 / q : 0
end
return y .* x
end
_testmode!(a::Dropout, test) = (a.active = !test)
"""
LayerNorm(h::Integer)
A [normalisation layer](https://arxiv.org/pdf/1607.06450.pdf) designed to be
used with recurrent hidden states of size `h`. Normalises the mean/stddev of
each input before applying a per-neuron gain/bias.
"""
struct LayerNorm{T}
diag::Diagonal{T}
end
LayerNorm(h::Integer) =
LayerNorm(Diagonal(h))
treelike(LayerNorm)
(a::LayerNorm)(x) = a.diag(normalise(x))
function Base.show(io::IO, l::LayerNorm)
print(io, "LayerNorm(", length(l.diag.α), ")")
end

View File

@ -1,5 +1,7 @@
# TODO: broadcasting cat # TODO: broadcasting cat
combine(x, h) = vcat(x, h .* trues(1, size(x, 2))) combine(x::AbstractMatrix, h::AbstractVector) = vcat(x, h .* trues(1, size(x, 2)))
combine(x::AbstractVector, h::AbstractVector) = vcat(x, h)
combine(x::AbstractMatrix, h::AbstractMatrix) = vcat(x, h)
# Stateful recurrence # Stateful recurrence

View File

@ -1,14 +1,27 @@
using NNlib: log_fast
# Cost functions # Cost functions
mse(, y) = sum(( .- y).^2)/length(y) mse(, y) = sum(( .- y).^2)/length(y)
crossentropy(::AbstractVecOrMat, y::AbstractVecOrMat) = crossentropy(::AbstractVecOrMat, y::AbstractVecOrMat) =
-sum(y .* log.()) / size(y, 2) -sum(y .* log_fast.()) / size(y, 2)
@deprecate logloss(x, y) crossentropy(x, y) @deprecate logloss(x, y) crossentropy(x, y)
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat) function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat)
logŷ = logŷ .- maximum(logŷ, 1) logŷ = logŷ .- maximum(logŷ, 1)
ypred = logŷ .- log.(sum(exp.(logŷ), 1)) ypred = logŷ .- log_fast.(sum(exp.(logŷ), 1))
-sum(y .* ypred) / size(y, 2) -sum(y .* ypred) / size(y, 2)
end end
"""
normalise(x::AbstractVecOrMat)
Normalise each column of `x` to mean 0 and standard deviation 1.
"""
function normalise(x::AbstractVecOrMat)
μ′ = mean(x, 1)
σ = std(x, 1, mean = μ′)
return (x .- μ′) ./ σ
end

View File

@ -1,3 +1,5 @@
import Base: *
struct OneHotVector <: AbstractVector{Bool} struct OneHotVector <: AbstractVector{Bool}
ix::UInt32 ix::UInt32
of::UInt32 of::UInt32
@ -7,7 +9,7 @@ Base.size(xs::OneHotVector) = (Int64(xs.of),)
Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
Base.:*(A::AbstractMatrix, b::OneHotVector) = A[:, b.ix] A::AbstractMatrix * b::OneHotVector = A[:, b.ix]
struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool} struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
height::Int height::Int
@ -18,7 +20,7 @@ Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i] Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i]
Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)] A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...]) Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
@ -40,10 +42,22 @@ function onehot(l, labels)
OneHotVector(i, length(labels)) OneHotVector(i, length(labels))
end end
onehotbatch(ls, labels) = OneHotMatrix(length(labels), [onehot(l, labels) for l in ls]) function onehot(l, labels, unk)
i = findfirst(labels, l)
i > 0 || return onehot(unk, labels)
OneHotVector(i, length(labels))
end
onehotbatch(ls, labels, unk...) =
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
argmax(y::AbstractVector, labels = 1:length(y)) = argmax(y::AbstractVector, labels = 1:length(y)) =
labels[findfirst(y, maximum(y))] labels[findfirst(y, maximum(y))]
argmax(y::AbstractMatrix, l...) = argmax(y::AbstractMatrix, l...) =
squeeze(mapslices(y -> argmax(y, l...), y, 1), 1) squeeze(mapslices(y -> argmax(y, l...), y, 1), 1)
# Ambiguity hack
a::TrackedMatrix * b::OneHotVector = TrackedArray(Tracker.Call(*, a, b))
a::TrackedMatrix * b::OneHotMatrix = TrackedArray(Tracker.Call(*, a, b))

View File

@ -38,7 +38,7 @@ function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
acc = zeros(p.x) .+ ϵ acc = zeros(p.x) .+ ϵ
function () function ()
@. acc = ρ * acc + (1 - ρ) * p.Δ ^ 2 @. acc = ρ * acc + (1 - ρ) * p.Δ ^ 2
@. p.Δ /= acc * η @. p.Δ *= η / acc
end end
end end
@ -46,7 +46,7 @@ function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
acc = zeros(p.x) .+ ϵ acc = zeros(p.x) .+ ϵ
function () function ()
@. acc += p.Δ ^ 2 @. acc += p.Δ ^ 2
@. p.Δ /= acc * η @. p.Δ *= η / acc
end end
end end

View File

@ -1,8 +1,8 @@
using Juno using Juno
using Flux.Tracker: back! using Flux.Tracker: back!
tocb(f) = f runall(f) = f
tocb(fs::AbstractVector) = () -> foreach(call, fs) runall(fs::AbstractVector) = () -> foreach(call, fs)
""" """
train!(loss, data, opt; cb = () -> ()) train!(loss, data, opt; cb = () -> ())
@ -11,10 +11,11 @@ For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
backpropagation and calls the optimizer `opt` and the callback `cb` backpropagation and calls the optimizer `opt` and the callback `cb`
(i.e. `opt()` and `cb()`). (i.e. `opt()` and `cb()`).
Multiple callbacks can be passed to `cb` as an array. Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
""" """
function train!(loss, data, opt; cb = () -> ()) function train!(loss, data, opt; cb = () -> ())
cb = tocb(cb) cb = runall(cb)
opt = runall(opt)
@progress for d in data @progress for d in data
l = loss(d...) l = loss(d...)
isinf(l.data[]) && error("Loss is Inf") isinf(l.data[]) && error("Loss is Inf")

View File

@ -1,6 +1,6 @@
module Tracker module Tracker
export TrackedArray, param, back! export TrackedArray, TrackedVector, TrackedMatrix, param, back!
data(x) = x data(x) = x
istracked(x) = false istracked(x) = false
@ -38,7 +38,9 @@ TrackedArray(c::Call) = TrackedArray(c, c())
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x)) TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
param(xs) = TrackedArray(AbstractFloat.(xs)) isleaf(x::TrackedArray) = x.f == Call(nothing)
param(xs) = TrackedArray(map(x -> AbstractFloat(x), xs))
param(xs::Real) = param(fill(xs)) param(xs::Real) = param(fill(xs))
istracked(x::TrackedArray) = true istracked(x::TrackedArray) = true
@ -56,6 +58,18 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T) Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
value(x) = x
value(x::TrackedArray) = data(x)
value(x::TrackedScalar) = data(x)[]
Base.:(==)(x::TrackedArray, y) = value(x) == y
Base.:(==)(y, x::TrackedArray) = y == value(x)
Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(x)
Base.isless(x::TrackedScalar, y) = isless(value(x), y)
Base.isless(x, y::TrackedScalar) = isless(x, value(y))
Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(value(x), value(y))
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} = Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
print(io, "TrackedArray{…,$A}") print(io, "TrackedArray{…,$A}")
@ -70,6 +84,9 @@ function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = tru
end end
end end
Base.setindex!(xs::TrackedArray, v, i...) =
error("Can't differentiate `setindex!`")
include("back.jl") include("back.jl")
include("lib.jl") include("lib.jl")
include("numeric.jl") include("numeric.jl")

View File

@ -1,5 +1,3 @@
import Base: *
toarray(xs::AbstractArray, ys::AbstractArray) = ys toarray(xs::AbstractArray, ys::AbstractArray) = ys
toarray(xs::AbstractArray, y) = similar(xs, typeof(y), ()) .= y toarray(xs::AbstractArray, y) = similar(xs, typeof(y), ()) .= y
@ -59,21 +57,58 @@ back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .=
Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...) Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...)
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...) Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
Base.mean(xs::TrackedArray) = TrackedArray(Call(mean, xs), toarray(xs.data, mean(xs.data)))
Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region))
# Hacks to get std working
Base.std(x::TrackedArray; mean = Base.mean(x)) =
sqrt.(sum((x .- mean).^2) ./ (length(x)-1))
Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) =
sqrt.(sum((x .- mean).^2, dim) ./ (size(x, dim)-1))
back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./ length(xs.data))
back(::typeof(mean), Δ, xs::TrackedArray, region) =
back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...)))
# BLAS # BLAS
a::TrackedMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b)) for f in :[*, Ac_mul_B].args
a::TrackedMatrix * b::AbstractMatrix = TrackedArray(Call(*, a, b)) @eval begin
a::AbstractMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b)) import Base.$f
$f(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b))
$f(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call($f, a, b))
$f(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b))
a::TrackedMatrix * b::TrackedVector = TrackedArray(Call(*, a, b)) $f(a::TrackedMatrix, b::TrackedVector) = TrackedArray(Call($f, a, b))
a::TrackedMatrix * b::AbstractVector = TrackedArray(Call(*, a, b)) $f(a::TrackedMatrix, b::AbstractVector) = TrackedArray(Call($f, a, b))
a::AbstractMatrix * b::TrackedVector = TrackedArray(Call(*, a, b)) $f(a::AbstractMatrix, b::TrackedVector) = TrackedArray(Call($f, a, b))
$f(a::TrackedVector, b::TrackedVector) = TrackedArray(Call($f, a, b))
$f(a::TrackedVector, b::AbstractVector) = TrackedArray(Call($f, a, b))
$f(a::AbstractVector, b::TrackedVector) = TrackedArray(Call($f, a, b))
end
end
function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat) function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat)
@back(a, A_mul_Bt(Δ, data(b))) @back(a, A_mul_Bt(Δ, data(b)))
@back(b, At_mul_B(data(a), Δ)) @back(b, At_mul_B(data(a), Δ))
end end
function back(::typeof(Ac_mul_B), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
@back(a, A_mul_Bt(Δ, data(b))')
@back(b, *(data(a), Δ))
end
# Fast path for matrix-vector
function back(::typeof(*), Δ::AbstractVector, W::TrackedMatrix, x::AbstractVector)
if isleaf(W)
W.grad .+= Δ .* data(x).'
else
back(W, A_mul_Bt(Δ, data(x)))
end
@back(x, At_mul_B(data(W), Δ))
end
# NNlib # NNlib
import NNlib: softmax, ∇softmax import NNlib: softmax, ∇softmax

View File

@ -1,6 +1,9 @@
children(x) = () children(x) = ()
mapchildren(f, x) = x mapchildren(f, x) = x
children(x::Tuple) = x
mapchildren(f, x::Tuple) = map(f, x)
function treelike(T, fs = fieldnames(T)) function treelike(T, fs = fieldnames(T))
@eval begin @eval begin
children(x::$T) = ($([:(x.$f) for f in fs]...),) children(x::$T) = ($([:(x.$f) for f in fs]...),)
@ -32,3 +35,5 @@ function params(m)
prefor(p -> p isa TrackedArray && push!(ps, p), m) prefor(p -> p isa TrackedArray && push!(ps, p), m)
return ps return ps
end end
params(m...) = params(m)

3
test/data.jl Normal file
View File

@ -0,0 +1,3 @@
using Flux.Data
@test cmudict()["CATASTROPHE"] == :[K,AH0,T,AE1,S,T,R,AH0,F,IY0].args

View File

@ -0,0 +1,28 @@
using Flux: testmode!
@testset "Dropout" begin
x = [1.,2.,3.]
@test x == testmode!(Dropout(0.1))(x)
@test x == Dropout(0)(x)
@test zeros(x) == Dropout(1)(x)
x = rand(100)
m = Dropout(0.9)
y = m(x)
@test count(a->a==0, y) > 50
testmode!(m)
y = m(x)
@test count(a->a==0, y) == 0
testmode!(m, false)
y = m(x)
@test count(a->a==0, y) > 50
x = rand(100)
m = Chain(Dense(100,100),
Dropout(0.9))
y = m(x)
@test count(a->a == 0, y) > 50
testmode!(m)
y = m(x)
@test count(a->a == 0, y) == 0
end

View File

@ -4,5 +4,6 @@ using Flux, Base.Test
include("utils.jl") include("utils.jl")
include("tracker.jl") include("tracker.jl")
include("layers/normalisation.jl")
end end

View File

@ -9,6 +9,8 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2) @test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2) @test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
@test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5)) @test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))
@test gradtest(x -> softmax(x).*(1:3), 3) @test gradtest(x -> softmax(x).*(1:3), 3)
@ -22,23 +24,22 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(vcat, rand(5), rand(3)) @test gradtest(vcat, rand(5), rand(3))
@test gradtest(vcat, rand(2,3), rand(3,3)) @test gradtest(vcat, rand(2,3), rand(3,3))
@testset "mean" begin
@test gradtest(mean, rand(2, 3))
@test gradtest(x -> mean(x, 1), rand(2, 3))
@test gradtest(x -> mean(x, 2), rand(2, 3))
@test gradtest(x -> mean(x, 3), rand(2, 3, 4))
@test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4))
end
@test gradtest(x -> std(x), rand(5,5))
@test gradtest(x -> std(x, 1), rand(5,5))
@test gradtest(rand(5)) do x @test gradtest(rand(5)) do x
y = x.^2 y = x.^2
2y + x 2y + x
end end
for T in [Float32, Float64]
@test isa(param(T(1)), TrackedArray{T, 0})
@test isa(param(rand(T, 2)), TrackedArray{T, 1})
@test isa(param(rand(T, 2,2)), TrackedArray{T, 2})
end
# TODO: do we wand this behaviour ??
F = typeof(AbstractFloat(1))
for T in [Int32, Int64]
@test isa(param(T(1)), TrackedArray{F, 0})
@test isa(param(rand(T, 2)), TrackedArray{F, 1})
@test isa(param(rand(T, 2,2)), TrackedArray{F, 2})
end
end #testset end #testset