diff --git a/README.md b/README.md index 0baa74d4..10d611f1 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ Flux has powerful high-level features, and common architectures can be defined i ```julia model = Chain( - Dense(768, 128), + Dense(768, 128, σ), LSTM(128, 256) LSTM(256, 128) Dense(128, 10), diff --git a/docs/src/gpu.md b/docs/src/gpu.md index 253904ad..6be2d7b0 100644 --- a/docs/src/gpu.md +++ b/docs/src/gpu.md @@ -4,6 +4,8 @@ Support for array operations on other hardware backends, like GPUs, is provided For example, we can use `CuArrays` (with the `cu` converter) to run our [basic example](models/basics.md) on an NVIDIA GPU. +(Note that you need to build Julia 0.6 from source and have CUDA available to use CuArrays – please see the [CUDAnative.jl](https://github.com/JuliaGPU/CUDAnative.jl) instructions for more details.) + ```julia using CuArrays diff --git a/docs/src/index.md b/docs/src/index.md index 86c9c3dc..afeb2075 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -14,3 +14,5 @@ Pkg.test("Flux") # Check things installed correctly ``` Start with the [basics](models/basics.md). The [model zoo](https://github.com/FluxML/model-zoo/) is also a good starting point for many common kinds of models. + +See [GPU support](gpu.md) for more details on installing and using Flux with GPUs. diff --git a/docs/src/models/basics.md b/docs/src/models/basics.md index 02225279..96efc7b8 100644 --- a/docs/src/models/basics.md +++ b/docs/src/models/basics.md @@ -28,13 +28,15 @@ l = loss(x, y) back!(l) ``` -`loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then calculates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model. +`loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then accumulates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model. ```julia -W.grad +using Flux.Tracker: grad, update! -# Update the parameter -W.data .-= 0.1(W.grad) +Δ = grad(W) + +# Update the parameter and reset the gradient +update!(W, -0.1Δ) loss(x, y) # ~ 2.5 ``` diff --git a/docs/src/models/regularisation.md b/docs/src/models/regularisation.md index d4325a53..cd53544f 100644 --- a/docs/src/models/regularisation.md +++ b/docs/src/models/regularisation.md @@ -7,6 +7,7 @@ add the result to the overall loss. For example, say we have a simple regression. ```julia +using Flux: crossentropy m = Dense(10, 5) loss(x, y) = crossentropy(softmax(m(x)), y) ``` @@ -43,3 +44,19 @@ loss(x, y) = crossentropy(m(x), y) + sum(vecnorm, params(m)) loss(rand(28^2), rand(10)) ``` + +One can also easily add per-layer regularisation via the `activations` function: + +```julia +julia> c = Chain(Dense(10,5,σ),Dense(5,2),softmax) +Chain(Dense(10, 5, NNlib.σ), Dense(5, 2), NNlib.softmax) + +julia> activations(c, rand(10)) +3-element Array{Any,1}: + param([0.71068, 0.831145, 0.751219, 0.227116, 0.553074]) + param([0.0330606, -0.456104]) + param([0.61991, 0.38009]) + +julia> sum(vecnorm, ans) +2.639678767773633 (tracked) +``` diff --git a/docs/src/training/optimisers.md b/docs/src/training/optimisers.md index 56f511e4..ac58f6d0 100644 --- a/docs/src/training/optimisers.md +++ b/docs/src/training/optimisers.md @@ -17,16 +17,17 @@ back!(l) We want to update each parameter, using the gradient, in order to improve (reduce) the loss. 
Here's one way to do that: ```julia -function update() +using Flux.Tracker: grad, update! + +function sgd() η = 0.1 # Learning Rate for p in (W, b) - p.data .-= η .* p.grad # Apply the update - p.grad .= 0 # Clear the gradient + update!(p, -η * grad(p)) end end ``` -If we call `update`, the parameters `W` and `b` will change and our loss should go down. +If we call `sgd`, the parameters `W` and `b` will change and our loss should go down. There are two pieces here: one is that we need a list of trainable parameters for the model (`[W, b]` in this case), and the other is the update step. In this case the update is simply gradient descent (`x .-= η .* Δ`), but we might choose to do something more advanced, like adding momentum. diff --git a/src/Flux.jl b/src/Flux.jl index eeda5492..0d78024b 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -23,7 +23,7 @@ include("optimise/Optimise.jl") using .Optimise using .Optimise: @epochs export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov, - RMSProp, ADAGrad, ADADelta, AMSGrad + RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM include("utils.jl") include("onehot.jl") diff --git a/src/layers/basic.jl b/src/layers/basic.jl index ad374643..cf89df41 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -38,6 +38,11 @@ function Base.show(io::IO, c::Chain) print(io, ")") end +# Seem to need this for `accumulate`; try removing on 0.7 +Base.rcum_promote_type(op, ::Type, ::Type{Any}) = Any + +activations(c::Chain, x) = accumulate((x, m) -> m(x), x, c.layers) + """ Dense(in::Integer, out::Integer, σ = identity) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 39d3394d..38310aad 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1,5 +1,10 @@ using NNlib: conv +@generated sub2(::Type{Val{N}}) where N = :(Val{$(N-2)}) + +expand(N, i::Tuple) = i +expand(N, i::Integer) = ntuple(_ -> i, N) + """ Conv(size, in=>out) Conv(size, in=>out, relu) @@ -10,7 +15,7 @@ Standard convolutional layer. `size` should be a tuple like `(2, 2)`. Data should be stored in WHCN order. In other words, a 100×100 RGB image would be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array. -Takes the keyword arguments `pad` and `stride`. +Takes the keyword arguments `pad`, `stride` and `dilation`. """ struct Conv{N,F,A,V} σ::F @@ -18,17 +23,17 @@ struct Conv{N,F,A,V} bias::V stride::NTuple{N,Int} pad::NTuple{N,Int} + dilation::NTuple{N,Int} end -Conv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity; - stride = 1, pad = 0) where T = - Conv(σ, w, b, stride, pad) +Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity; + stride = 1, pad = 0, dilation = 1) where {T,N} = + Conv(σ, w, b, expand.(sub2(Val{N}), (stride, pad, dilation))...) 
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn, - stride::NTuple{N,Integer} = map(_->1,k), - pad::NTuple{N,Integer} = map(_->0,k)) where N = + stride = 1, pad = 0, dilation = 1) where N = Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ, - stride = stride, pad = pad) + stride = stride, pad = pad, dilation = dilation) Flux.treelike(Conv) @@ -36,7 +41,7 @@ function (c::Conv)(x) # TODO: breaks gpu broadcast :( # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1))) σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) - σ.(conv(x, c.weight, stride = c.stride, pad = c.pad) .+ b) + σ.(conv(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b) end function Base.show(io::IO, l::Conv) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 74905a36..54f5eb56 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -31,15 +31,14 @@ function Dropout(p) Dropout{typeof(p)}(p, true) end +_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0) + function (a::Dropout)(x) a.active || return x y = similar(x) rand!(y) - q = 1 - a.p - @inbounds for i=1:length(y) - y[i] = y[i] > a.p ? 1 / q : 0 - end - return y .* x + y .= _dropout_kernel.(y, a.p, 1 - a.p) + return x .* y end _testmode!(a::Dropout, test) = (a.active = !test) diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index ccd4fe4c..ba80e8a6 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -15,9 +15,9 @@ function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight end """ - binarycrossentropy(ŷ, y) + binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) -Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`. +Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerical stability. julia> binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0.]) 3-element Array{Float64,1}: @@ -25,7 +25,7 @@ Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`. 0.352317 0.86167 """ -binarycrossentropy(ŷ, y) = -y*log(ŷ) - (1 - y)*log(1 - ŷ) +binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ) """ logitbinarycrossentropy(logŷ, y) diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index eb9aaa87..810793b6 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -1,7 +1,8 @@ module Optimise export train!, - SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad + SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov, + RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM struct Param{T} x::T diff --git a/src/optimise/interface.jl b/src/optimise/interface.jl index 01c76391..096e2d87 100644 --- a/src/optimise/interface.jl +++ b/src/optimise/interface.jl @@ -99,3 +99,12 @@ tuning. """ AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) + +""" + NADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) + +[NADAM](https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ) optimiser. Parameters other +than learning rate don't need tuning. 
+""" +NADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = + optimiser(ps, p->nadam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1)) diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index 3a3c8945..112aaa73 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -35,7 +35,7 @@ function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8) acc = zeros(p.x) function () @. acc = ρ * acc + (1 - ρ) * p.Δ^2 - @. p.Δ *= η / (√acc + ϵ) + @. p.Δ *= η / √(acc + ϵ) end end @@ -43,7 +43,7 @@ function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8) acc = zeros(p.x) .+ ϵ function () @. acc += p.Δ^2 - @. p.Δ *= η / √acc + @. p.Δ *= η / √(acc + ϵ) end end @@ -64,7 +64,7 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ function () @. mt = β1 * mt + (1 - β1) * p.Δ @. vt = β2 * vt + (1 - β2) * p.Δ^2 - @. p.Δ = mt / (1 - β1p) / (√(vt / (1 - β2p)) + ϵ) * η + @. p.Δ = mt / (1 - β1p) / √(vt / (1 - β2p) + ϵ) * η β1p *= β1 β2p *= β2 end @@ -94,6 +94,19 @@ function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, end end +function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8) + mt = zeros(p.x) + vt = zeros(p.x) + β1p, β2p = β1, β2 + function () + @. mt = β1 * mt + (1 - β1) * p.Δ + @. vt = β2 * vt + (1 - β2) * p.Δ^2 + @. p.Δ = (β1 * mt / (1 - β1 * β1p) + (1 - β1) * p.Δ / (1 - β1p)) / √(vt * β2 / (1 - β2p) + ϵ) * η + β1p *= β1 + β2p *= β2 + end +end + clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh) function expdecay(p::Param, γ::Real) diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 401a1c51..8ad8573e 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -37,8 +37,6 @@ function train!(loss, data, opt; cb = () -> ()) opt = runall(opt) @progress for d in data l = loss(d...) - isinf(l) && error("Loss is Inf") - isnan(l) && error("Loss is NaN") @interrupts back!(l) opt() cb() == :stop && break diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 8d4a8ca7..1296d179 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -10,6 +10,7 @@ istracked(x) = tracker(x) ≠ nothing isleaf(x) = !istracked(x) || isleaf(tracker(x)) data(x) = istracked(x) ? data(tracker(x)) : x grad(x) = grad(tracker(x)) +grad(::Void) = nothing struct Call{F,As<:Tuple} func::F @@ -46,11 +47,27 @@ isleaf(x::Tracked) = x.f == Call(nothing) data(x::Tracked) = x.data grad(x::Tracked) = x.grad +function update!(x, Δ) + tracker(x).data += Δ + tracker(x).grad .= 0 + return x +end + include("back.jl") include("scalar.jl") include("array.jl") include("numeric.jl") +""" + hook(f, x) -> x′ + +Hook into gradient backpropagation. `x` is unmodified, but when backpropagating +`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse +the sign of the gradient applied to `x`. +""" +hook(f, x) = istracked(x) ? 
track(hook, f, x) : x +back(::typeof(hook), Δ, f, x) = @back(x, f(Δ)) + param(x::Number) = TrackedReal(float(x)) param(xs::AbstractArray) = TrackedArray(float.(xs)) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index e11296ab..7a54d2eb 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -93,6 +93,26 @@ function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1) back(xs, Δ′) end + +_repeat(A, inner, outer) = Base.repeat(A; inner=inner, outer=outer) +Base.repeat(A::TrackedArray; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A))) = track(_repeat, A, inner, outer) + +function back(::typeof(_repeat), Δ, xs::TrackedArray, inner, outer) + Δ′ = similar(xs.data) + Δ′ .= 0 + S = size(xs.data) + + # Loop through each element of Δ, calculate source dimensions, accumulate into Δ′ + for (dest_idx, val) in enumerate(IndexCartesian(), Δ) + # First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then + # wrap around based on original size S. + src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)] + Δ′[src_idx...] += val + end + back(xs, Δ′) +end + + for f in [:vcat, :hcat] @eval begin # This section is a bit of a hack since julia doesn't have a standardised @@ -314,18 +334,18 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs) back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs))) # TODO: can store kwargs efficiently in namedtuples -_conv(x, w, stride, pad) = conv(x, w, stride = stride, pad = pad) +_conv(x, w, stride, pad, dilation) = conv(x, w, stride = stride, pad = pad, dilation = dilation) -conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N = - track(_conv, x, w, stride, pad) -conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N = - track(_conv, x, w, stride, pad) -conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0) where N = - track(_conv, x, w, stride, pad) +conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N = + track(_conv, x, w, stride, pad, dilation) +conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N = + track(_conv, x, w, stride, pad, dilation) +conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N = + track(_conv, x, w, stride, pad, dilation) -function back(::typeof(_conv), Δ, x, w, stride, pad) - @back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad)) - @back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad)) +function back(::typeof(_conv), Δ, x, w, stride, pad, dilation) + @back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation)) + @back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation)) end _maxpool(x, k, pad, stride) = maxpool(x, k; pad = pad, stride = stride) diff --git a/src/tracker/numeric.jl b/src/tracker/numeric.jl index 272f9ba4..755e1f7d 100644 --- a/src/tracker/numeric.jl +++ b/src/tracker/numeric.jl @@ -1,4 +1,4 @@ -function gradient(f, xs::AbstractArray...) +function gradient(f, xs...) 
xs = param.(xs) back!(f(xs...)) grad.(xs) diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 632046cd..773943c0 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -8,7 +8,11 @@ tracker(x::TrackedReal) = x.tracker track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x))) -back!(x::TrackedReal) = back!(x, 1) +function back!(x::TrackedReal) + isinf(x) && error("Loss is Inf") + isnan(x) && error("Loss is NaN") + return back!(x, 1) +end function Base.show(io::IO, x::TrackedReal) show(io, data(x)) @@ -19,14 +23,16 @@ Base.decompose(x::TrackedReal) = Base.decompose(data(x)) Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x -Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T = - TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data))) - Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x)) +Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} = + error("Not implemented: convert tracked $S to tracked $T") + Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y) Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y) +Base.eps(x::TrackedReal) = eps(data(x)) + for f in :[isinf, isnan, isfinite].args @eval Base.$f(x::TrackedReal) = Base.$f(data(x)) end @@ -91,3 +97,18 @@ Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i) back(::typeof(getindex), Δ, t, i) = back(t, ntuple(j -> i == j ? Δ : 0, length(t))) + +# Array collection + +function collect(xs) + xs = Base.collect(xs) + track(Call(collect, xs), data.(xs)) +end + +function scan(c::Call{typeof(collect)}) + foreach(scan, c.args[1]) +end + +function back(::typeof(collect), Δ, xs) + foreach((x, Δ) -> @back(x, Δ), xs, Δ) +end diff --git a/test/layers/stateless.jl b/test/layers/stateless.jl index ecfa7014..31a67aa7 100644 --- a/test/layers/stateless.jl +++ b/test/layers/stateless.jl @@ -1,7 +1,9 @@ using Base.Test -using Flux: onehotbatch, mse, crossentropy, logitcrossentropy, +using Flux: onehotbatch, mse, crossentropy, logitcrossentropy, σ, binarycrossentropy, logitbinarycrossentropy +const ϵ = 1e-7 + @testset "losses" begin # First, regression-style y's y = [1, 1, 0, 0] @@ -40,10 +42,11 @@ using Flux: onehotbatch, mse, crossentropy, logitcrossentropy, logŷ, y = randn(3), rand(3) @testset "binarycrossentropy" begin - @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ)) + @test binarycrossentropy.(σ.(logŷ), y; ϵ=0) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ)) + @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 - y).*log.(1 - σ.(logŷ) .+ eps.(σ.(logŷ))) end - + @testset "logitbinarycrossentropy" begin - @test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y) + @test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y; ϵ=0) end end diff --git a/test/optimise.jl b/test/optimise.jl index ae7ec8fe..c896bb39 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -3,7 +3,7 @@ using Flux.Tracker @testset "Optimise" begin w = randn(10, 10) - @testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad] + @testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM] w′ = param(randn(10, 10)) loss(x) = Flux.mse(w*x, w′*x) opt = Opt([w′]) diff --git a/test/tracker.jl b/test/tracker.jl index 434148f0..66c08f62 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -1,5 +1,5 @@ using Flux.Tracker, Base.Test, 
NNlib -using Flux.Tracker: TrackedReal, gradcheck +using Flux.Tracker: TrackedReal, gradcheck, grad using NNlib: conv gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...) @@ -114,6 +114,9 @@ end @test gradtest(x -> repmat(x, 5,5), rand(4,5)) @test gradtest(x -> repmat(x, 5), rand(4,5)) +@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5)) +@test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3)) + @test gradtest(kron, rand(5), rand(3)) @test gradtest(kron, rand(5), rand(3), rand(8)) @test gradtest(kron, rand(5,1), rand(3,1)) @@ -220,4 +223,13 @@ b = param(rand()) Tracker.back!(b) @test Tracker.grad(b) == 1 +@testset "collect" begin + x, y = param(2), param(3) + xy = Tracker.collect([x, y]) + @test xy isa TrackedArray{Float64} + z = xy[1]*xy[2] + back!(z) + @test grad.((x,y)) == (3, 2) +end + end #testset
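
The `grad`/`update!` API and the newly exported `NADAM` optimiser shown above combine into a training step as follows. This is a usage sketch rather than part of the patch; it assumes the changes above are applied and leans on existing Flux names (`Dense`, `Flux.mse`, `params`, `Flux.train!`).

```julia
using Flux
using Flux.Tracker: back!, grad, update!

m = Dense(10, 2)
loss(x, y) = Flux.mse(m(x), y)
x, y = rand(10), rand(2)

# One manual gradient step: `update!` applies the change and zeroes the stored gradient.
back!(loss(x, y))
for p in params(m)
  update!(p, -0.1 * grad(p))
end

# Alternatively, hand the parameters to the newly exported NADAM optimiser.
opt = NADAM(params(m))
Flux.train!(loss, [(x, y)], opt)
```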
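
The `dilation` keyword threaded through `Conv` and the tracked `conv` wrappers can be exercised directly. A minimal sketch, again assuming the patch is applied; the commented output size follows the usual dilated-convolution arithmetic with `pad = 0` and `stride = 1`.

```julia
using Flux

# A 3×3 kernel with dilation 2 covers a 5×5 receptive field, so a 28×28
# single-channel image (WHCN order, batch of one) shrinks to 24×24 spatially.
c = Conv((3, 3), 1 => 8, relu, dilation = 2)
x = rand(28, 28, 1, 1)
size(c(x))  # expected (24, 24, 8, 1)
```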
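
Since `hook` returns an ordinary tracked value, it composes with the rest of Tracker, for example to clip the gradient that reaches a particular parameter. A minimal sketch under the same assumption that the patch is applied:

```julia
using Flux.Tracker: back!, grad, hook, param

x = param([1.0, 2.0, 3.0])

# The forward value is unchanged; during backprop the incoming gradient is clamped.
y = sum(hook(Δ -> clamp.(Δ, -0.5, 0.5), x) .^ 2)
back!(y)
grad(x)  # [0.5, 0.5, 0.5] rather than the unclipped 2 .* [1.0, 2.0, 3.0]
```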