diff --git a/README.md b/README.md index 4785c55c..0baa74d4 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@

-[![Build Status](https://travis-ci.org/FluxML/Flux.jl.svg?branch=master)](https://travis-ci.org/FluxML/Flux.jl) [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://fluxml.github.io/Flux.jl/stable/) [![](https://img.shields.io/badge/chat-on%20slack-yellow.svg)](https://slackinvite.julialang.org/) +[![Build Status](https://travis-ci.org/FluxML/Flux.jl.svg?branch=master)](https://travis-ci.org/FluxML/Flux.jl) [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://fluxml.github.io/Flux.jl/stable/) [![](https://img.shields.io/badge/chat-on%20slack-yellow.svg)](https://slackinvite.julialang.org/) [![DOI](http://joss.theoj.org/papers/10.21105/joss.00602/status.svg)](https://doi.org/10.21105/joss.00602) Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, and provides lightweight abstractions on top of Julia's native GPU and AD support. Flux makes the easy things easy while remaining fully hackable. @@ -12,6 +12,18 @@ julia> Pkg.add("Flux") See the [documentation](http://fluxml.github.io/Flux.jl/) or the [model zoo](https://github.com/FluxML/model-zoo/) for examples. +If you use Flux in research, please cite the following paper: + +``` +@article{innes:2018, + author = {Mike Innes}, + title = {Flux: Elegant Machine Learning with Julia}, + journal = {Journal of Open Source Software}, + year = {2018}, + doi = {10.21105/joss.00602}, +} +``` + ## Features Flux has powerful high-level features, and common architectures can be defined in a few lines. @@ -79,3 +91,9 @@ For general questions and help, check out Julia's [community forum](https://disc Flux development is carried out via our [GitHub issues](https://github.com/FluxML/Flux.jl/issues), so feel free to open feature requests or PRs here. For more informal discussions we'd love to have you on the [Julia slack](https://slackinvite.julialang.org/), where we hang out on the #machine-learning channel. + +## Related Packages + +Check out [Metalhead.jl](https://github.com/FluxML/Metalhead.jl) for common computer vision datasets and trained models. + +[MLDatasets.jl](https://github.com/JuliaML/MLDatasets.jl) provides further common datasets. diff --git a/docs/make.jl b/docs/make.jl index d7f14d8e..ed6a8c8b 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -18,6 +18,8 @@ makedocs(modules=[Flux, NNlib], "One-Hot Encoding" => "data/onehot.md", "GPU Support" => "gpu.md", "Saving & Loading" => "saving.md", + "Internals" => + ["Backpropagation" => "internals/tracker.md"], "Community" => "community.md"]) deploydocs( diff --git a/docs/src/internals/tracker.md b/docs/src/internals/tracker.md new file mode 100644 index 00000000..b9addc34 --- /dev/null +++ b/docs/src/internals/tracker.md @@ -0,0 +1,156 @@ +# Flux.Tracker + +Backpropagation, or reverse-mode automatic differentiation, is handled by the `Flux.Tracker` module. + +```julia +julia> using Flux.Tracker +``` + +The `param` function converts a normal Julia array into a new object that, while behaving like an array, tracks extra information that allows us to calculate derivatives. For example, say we multiply two parameters: + +```julia +julia> W = param([1 2; 3 4]) +Tracked 2×2 Array{Float64,2}: + 1.0 2.0 + 3.0 4.0 + +julia> x = param([5, 6]) +Tracked 2-element Array{Float64,1}: + 5.0 + 6.0 + +julia> y = W*x +Tracked 2-element Array{Float64,1}: + 17.0 + 39.0 +``` + +The output `y` is also a `TrackedArray` object. We can now backpropagate sensitivities to `W` and `x` via the `back!` function, and see the gradients accumulated in the `W` and `x` tracked arrays: + +```julia +julia> Tracker.back!(y, [1, -1]) + +julia> W.grad +2×2 Array{Float64,2}: + 5.0 6.0 +-5.0 -6.0 + +julia> x.grad +2-element Array{Float64,1}: + -2.0 + -2.0 +``` + +## Internals + +All `Tracked*` objects (`TrackedArray`, `TrackedReal`) are light wrappers around the `Tracked` type, which you can access via the `.tracker` field. + +```julia +julia> x.tracker +Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Void,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0]) +``` + +The `Tracker` stores the value and gradient of a given object, which we've seen before. + +```julia +julia> x.tracker.data +2-element Array{Float64,1}: + 5.0 + 6.0 + +julia> x.tracker.grad +2-element Array{Float64,1}: + -2.0 + -2.0 +``` + +The tracker also contains a `Call` object, which simply represents a function call that was made at some point during the forward pass. For example, the `+` call would look like this: + +```julia +julia> Tracker.Call(+, 1, 2) +Flux.Tracker.Call{Base.#+,Tuple{Int64,Int64}}(+, (1, 2)) +``` + +In the case of the `y` we produced above, we can see that it stores the call that produced it -- that is, `W*x`. + +```julia +julia> y.tracker.f +Flux.Tracker.Call{...}(*, (param([1.0 2.0; 3.0 4.0]), param([5.0, 6.0]))) +``` + +Notice that because the arguments to the call may also be tracked arrays, storing their own calls, this means that `Tracker` ends up forming a data structure that records everything that happened during the forward pass (often known as a *tape*). + +When we call `back!(y, [1, -1])`, the sensitivities `[1, -1]` simply get forwarded to `y`'s call (`*`), effectively calling + +```julia +Tracker.back(*, [1, -1], W, x) +``` + +which in turn calculates the sensitivities of the arguments (`W` and `x`) and backpropagates through their calls. This is recursive, so it will walk the entire program graph and propagate gradients to the original model parameters. + +## Custom Gradients + +We can hook in to the processes above to implement custom gradients for a function or kernel. For a toy example, imagine a custom implementation of `minus`: + +```julia +julia> minus(a, b) = a - b +``` + +Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch: + +```julia +julia> minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b) +minus (generic function with 2 methods) +``` + +`Tracker.track` does two things: (1) it makes sure `minus` is called with *normal* array, not tracked ones (you can use `@show` inside `minus` to verify this), and (2) it uses the result to add a `minus` node to the tape. Look inside the result of calling `minus` to see what happened: + +```julia +julia> a, b = param([6,5,4]), param([1,2,3]) +(param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0])) + +julia> c = minus(a, b) +Tracked 3-element Array{Float64,1}: + 5.0 + 3.0 + 1.0 + +julia> c.tracker.f +Flux.Tracker.Call{...}(minus, (param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0]))) +``` + +Finally, we have to specify the gradient of `minus`. + +```julia +julia> Tracker.back(::typeof(minus), Δ, a, b) = + (Tracker.@back(a, Δ); Tracker.@back(b, -Δ)) +``` + +`@back(x, Δ)` tells the tracker to continue propagating the sensitivity `Δ` through `x`. Now, AD will work with any program that calls `minus`. + +```julia +julia> Flux.back!(c, 1) + +julia> a.grad +3-element Array{Float64,1}: + 1.0 + 1.0 + 1.0 + +julia> b.grad +3-element Array{Float64,1}: + -1.0 + -1.0 + -1.0 +``` + +## Notes + +For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, TrackedArray)` and so on. To do so, just define those extra signatures as needed: + +```julia +minus(a::AbstractArray, b::TrackedArray) = Tracker.track(minus, a, b) +minus(a::TrackedArray, b::AbstractArray) = Tracker.track(minus, a, b) +``` + +`@back` *must* be called exactly once on each tracked input argument. You do not need to do any special handling if one of the arguments is not tracked, as `@back` will just become a no-op. diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index 379268b3..c2056bb4 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -5,7 +5,7 @@ These core layers form the foundation of almost all neural networks. ```@docs Chain Dense -Conv2D +Conv ``` ## Recurrent Layers diff --git a/docs/src/models/regularisation.md b/docs/src/models/regularisation.md index d4325a53..70d06348 100644 --- a/docs/src/models/regularisation.md +++ b/docs/src/models/regularisation.md @@ -7,6 +7,7 @@ add the result to the overall loss. For example, say we have a simple regression. ```julia +using Flux: crossentropy m = Dense(10, 5) loss(x, y) = crossentropy(softmax(m(x)), y) ``` diff --git a/src/Flux.jl b/src/Flux.jl index 7746ecff..7d1d66e6 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -7,22 +7,23 @@ module Flux using Juno, Requires, Reexport using MacroTools: @forward -export Chain, Dense, RNN, LSTM, GRU, Conv, Conv2D, - Dropout, LayerNorm, BatchNorm, - SGD, ADAM, Momentum, Nesterov, AMSGrad, - param, params, mapleaves, cpu, gpu +export Chain, Dense, RNN, LSTM, GRU, Conv, + Dropout, LayerNorm, BatchNorm, + params, mapleaves, cpu, gpu @reexport using NNlib using NNlib: @fix include("tracker/Tracker.jl") using .Tracker -export Tracker -import .Tracker: data +using .Tracker: data +export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param include("optimise/Optimise.jl") using .Optimise using .Optimise: @epochs +export SGD, ADAM, AdaMax, Momentum, Nesterov, + RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM include("utils.jl") include("onehot.jl") diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 994648c2..c61676aa 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -10,7 +10,7 @@ Standard convolutional layer. `size` should be a tuple like `(2, 2)`. Data should be stored in WHCN order. In other words, a 100×100 RGB image would be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array. -Takes the keyword arguments `pad` and `stride`. +Takes the keyword arguments `pad`, `stride` and `dilation`. """ struct Conv{N,F,A,V} σ::F @@ -18,17 +18,19 @@ struct Conv{N,F,A,V} bias::V stride::NTuple{N,Int} pad::NTuple{N,Int} + dilation::NTuple{N,Int} end Conv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity; - stride = 1, pad = 0) where T = - Conv(σ, w, b, stride, pad) + stride = 1, pad = 0, dilation=1) where T = + Conv(σ, w, b, stride, pad, dilation) Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn, stride::NTuple{N,Integer} = map(_->1,k), - pad::NTuple{N,Integer} = map(_->0,k)) where N = + pad::NTuple{N,Integer} = map(_->0,k), + dilation::NTuple{N,Integer} = map(_->1,k)) where N = Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ, - stride = stride, pad = pad) + stride = stride, pad = pad, dilation = dilation) Flux.treelike(Conv) @@ -36,7 +38,7 @@ function (c::Conv)(x) # TODO: breaks gpu broadcast :( # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1))) σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) - σ.(conv(x, c.weight, stride = c.stride, pad = c.pad) .+ b) + σ.(conv(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b) end function Base.show(io::IO, l::Conv) @@ -45,6 +47,3 @@ function Base.show(io::IO, l::Conv) l.σ == identity || print(io, ", ", l.σ) print(io, ")") end - -# v0.5 -@deprecate Conv2D(args...; kw...) Conv(args...; kw...) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 74905a36..54f5eb56 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -31,15 +31,14 @@ function Dropout(p) Dropout{typeof(p)}(p, true) end +_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0) + function (a::Dropout)(x) a.active || return x y = similar(x) rand!(y) - q = 1 - a.p - @inbounds for i=1:length(y) - y[i] = y[i] > a.p ? 1 / q : 0 - end - return y .* x + y .= _dropout_kernel.(y, a.p, 1 - a.p) + return x .* y end _testmode!(a::Dropout, test) = (a.active = !test) diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index acec542e..0c541b93 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -1,7 +1,8 @@ module Optimise -export update!, params, train!, - SGD, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad +export train!, + SGD, ADAM, AdaMax, Momentum, Nesterov, + RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM struct Param{T} x::T diff --git a/src/optimise/interface.jl b/src/optimise/interface.jl index 42b05dc8..3a07f6ce 100644 --- a/src/optimise/interface.jl +++ b/src/optimise/interface.jl @@ -56,6 +56,15 @@ RMSProp(ps, η = 0.001; ρ = 0.9, ϵ = 1e-8, decay = 0) = ADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = optimiser(ps, p->adam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1)) +""" + AdaMax(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) + +[AdaMax](https://arxiv.org/abs/1412.6980v9) optimiser. Variant of ADAM based on +the ∞-norm. +""" +AdaMax(ps, η = 0.002; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = + optimiser(ps, p->adamax(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1)) + """ ADAGrad(params, η = 0.01; ϵ = 1e-8, decay = 0) @@ -82,3 +91,12 @@ tuning. """ AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1)) + +""" + NADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) + +[NADAM](https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ) optimiser. Parameters other +than learning rate don't need tuning. +""" +NADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) = + optimiser(ps, p->nadam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1)) diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index c09e6131..e3a4ed34 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -27,7 +27,7 @@ function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8) acc = zeros(p.x) function () @. acc = ρ * acc + (1 - ρ) * p.Δ^2 - @. p.Δ *= η / (√acc + ϵ) + @. p.Δ *= η / √(acc + ϵ) end end @@ -35,7 +35,7 @@ function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8) acc = zeros(p.x) .+ ϵ function () @. acc += p.Δ^2 - @. p.Δ *= η / √acc + @. p.Δ *= η / √(acc + ϵ) end end @@ -56,12 +56,24 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ function () @. mt = β1 * mt + (1 - β1) * p.Δ @. vt = β2 * vt + (1 - β2) * p.Δ^2 - @. p.Δ = mt / (1 - β1p) / (√(vt / (1 - β2p)) + ϵ) * η + @. p.Δ = mt / (1 - β1p) / √(vt / (1 - β2p) + ϵ) * η β1p *= β1 β2p *= β2 end end +function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8) + mt = zeros(p.x) + ut = zeros(p.x) + β1p = β1 + function () + @. mt = β1 * mt + (1 - β1) * p.Δ + @. ut = max(β2 * ut, abs(p.Δ)) + @. p.Δ = (η/(1 - β1p)) * mt/(ut + ϵ) + β1p *= β1 + end +end + function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8) mt = zeros(p.x) vt = zeros(p.x) .+ ϵ @@ -74,6 +86,19 @@ function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, end end +function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8) + mt = zeros(p.x) + vt = zeros(p.x) + β1p, β2p = β1, β2 + function () + @. mt = β1 * mt + (1 - β1) * p.Δ + @. vt = β2 * vt + (1 - β2) * p.Δ^2 + @. p.Δ = (β1 * mt / (1 - β1 * β1p) + (1 - β1) * p.Δ / (1 - β1p)) / √(vt * β2 / (1 - β2p) + ϵ) * η + β1p *= β1 + β2p *= β2 + end +end + clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh) function expdecay(p::Param, γ::Real) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index bb55ef73..7a54d2eb 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -41,7 +41,7 @@ end Base.setindex!(xs::TrackedArray, v, i...) = error("Can't differentiate `setindex!`") -back!(::TrackedArray) = error("Use back!(x, Δ)") +back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`") # Fallthrough methods @@ -81,21 +81,6 @@ back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ')) Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...) Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...) -Base.vcat(a::TrackedVector, b::TrackedVector) = track(vcat, a, b) -Base.vcat(a::TrackedVector, b::TrackedVector...) = track(vcat, a, b...) -Base.vcat(a::TrackedVector, b::AbstractVector) = track(vcat, a, b) -Base.vcat(a::AbstractVector, b::TrackedVector) = track(vcat, a, b) - -Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b) -Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat...) = track(vcat, a, b...) -Base.vcat(a::TrackedVecOrMat, b::AbstractVecOrMat) = track(vcat, a, b) -Base.vcat(a::AbstractVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b) - -Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = track(vcat, a, b) -Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = track(vcat, a, b...) -Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = track(vcat, a, b) -Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = track(vcat, a, b) - function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1) Δ′ = similar(xs.data) S = size(xs.data) @@ -108,15 +93,90 @@ function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1) back(xs, Δ′) end + +_repeat(A, inner, outer) = Base.repeat(A; inner=inner, outer=outer) +Base.repeat(A::TrackedArray; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A))) = track(_repeat, A, inner, outer) + +function back(::typeof(_repeat), Δ, xs::TrackedArray, inner, outer) + Δ′ = similar(xs.data) + Δ′ .= 0 + S = size(xs.data) + + # Loop through each element of Δ, calculate source dimensions, accumulate into Δ′ + for (dest_idx, val) in enumerate(IndexCartesian(), Δ) + # First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then + # wrap around based on original size S. + src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)] + Δ′[src_idx...] += val + end + back(xs, Δ′) +end + + +for f in [:vcat, :hcat] + @eval begin + # This section is a bit of a hack since julia doesn't have a standardised + # promotion mechanism for concatenation yet + # https://github.com/JuliaLang/julia/pull/20815 + + # It should support tracked concatenation with rank ∈ (1,2) with a + # TrackedArray anywhere among the arguments This works as long as base has + # other functions that captures `(::Union{Vector,RowVector,Matrix}...)`. + Base.$f(a::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a...) + + # It should support tracked concatenation with rank>2 if the TrackedArray is + # first + Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...) + Base.$f(a::TrackedArray, b::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a, b...) # resolves ambiguity introduced by previous row + + # It should support tracked concatenation with rank>2 if the TrackedArray is + # second + Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...) + Base.$f(a::Union{Vector,RowVector,Matrix}, b::TrackedArray, + c::Union{TrackedArray,Vector,RowVector,Matrix}...) = + track($f, a, b, c...) # resolves ambiguity introduced by previous row + end +end + function back(::typeof(vcat), Δ, xs...) - i = Base.tail(map(_ -> :, size(Δ))) start = 0 for xsi in xs + i = map(_ -> :, size(xsi)) |> Base.tail @back(xsi, Δ[start+1:start+size(xsi,1), i...]) start += size(xsi, 1) end end +function back(::typeof(hcat), Δ, xs...) + start = 0 + for xsi in xs + if ndims(xsi) == 1 + @back(xsi, Δ[:, start+1]) + else + i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail + @back(xsi, Δ[:, start+1:start+size(xsi,2), i...]) + end + start += size(xsi, 2) + end +end + +Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...) +Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...) + +function back(::typeof(cat), Δ, dims, Xs...) + start = ntuple(i -> 0, Val{ndims(Δ)}) + for xs in Xs + dim_xs = 1:ndims(xs) + till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)}) + + xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)}) + + @back(xs, reshape(Δ[xs_in_Δ...],size(xs))) + + start = start .+ till_xs + end +end + Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims) Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims)) Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims) @@ -156,12 +216,16 @@ Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs)) back(::typeof(prod), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= (prod(xs.data, dim...) ./ xs.data) .* Δ) back(::typeof(prod), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= (reshape(.*(circshift.([reshape(xs.data, length(xs.data))], 1:length(xs.data)-1)...), size(xs.data))) .* Δ) -Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...) Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...) Base.mean(xs::TrackedArray) = track(mean, xs) Base.mean(xs::TrackedArray, region) = track(mean, xs, region) +Base.maximum(xs::TrackedArray) = track(maximum, xs) +Base.maximum(xs::TrackedArray, region) = track(maximum, xs, region) +Base.minimum(xs::TrackedArray) = track(minimum, xs) +Base.minimum(xs::TrackedArray, region) = track(minimum, xs, region) + LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys) LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys) LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys) @@ -184,6 +248,31 @@ back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./ back(::typeof(mean), Δ, xs::TrackedArray, region) = back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...))) +function back(::typeof(maximum), Δ, xs::TrackedArray) + Δ′ = zeros(xs.data) + _, i = findmax(xs.data) + Δ′[i] = Δ + @back(xs, Δ′) +end +function back(::typeof(maximum), Δ, xs::TrackedArray, region) + Δ′ = zeros(xs.data) + _, is = findmax(xs.data, region) + Δ′[is] = Δ + @back(xs, Δ′) +end +function back(::typeof(minimum), Δ, xs::TrackedArray) + Δ′ = zeros(xs.data) + _, i = findmin(xs.data) + Δ′[i] = Δ + @back(xs, Δ′) +end +function back(::typeof(minimum), Δ, xs::TrackedArray, region) + Δ′ = zeros(xs.data) + _, is = findmin(xs.data, region) + Δ′[is] = Δ + @back(xs, Δ′) +end + # BLAS Base.diagm(x::TrackedVector) = track(diagm, x) @@ -245,18 +334,18 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs) back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs))) # TODO: can store kwargs efficiently in namedtuples -_conv(x, w, stride, pad) = conv(x, w, stride = stride, pad = pad) +_conv(x, w, stride, pad, dilation) = conv(x, w, stride = stride, pad = pad, dilation = dilation) -conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N = - track(_conv, x, w, stride, pad) -conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N = - track(_conv, x, w, stride, pad) -conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0) where N = - track(_conv, x, w, stride, pad) +conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N = + track(_conv, x, w, stride, pad, dilation) +conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N = + track(_conv, x, w, stride, pad, dilation) +conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N = + track(_conv, x, w, stride, pad, dilation) -function back(::typeof(_conv), Δ, x, w, stride, pad) - @back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad)) - @back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad)) +function back(::typeof(_conv), Δ, x, w, stride, pad, dilation) + @back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation)) + @back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation)) end _maxpool(x, k, pad, stride) = maxpool(x, k; pad = pad, stride = stride) diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 632046cd..5deaf66c 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -19,8 +19,9 @@ Base.decompose(x::TrackedReal) = Base.decompose(data(x)) Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x -Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T = - TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data))) +# This cuts derivatives, fix if needed. +# Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T = +# TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data))) Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x)) @@ -91,3 +92,18 @@ Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i) back(::typeof(getindex), Δ, t, i) = back(t, ntuple(j -> i == j ? Δ : 0, length(t))) + +# Array collection + +function collect(xs) + xs = Base.collect(xs) + track(Call(collect, xs), data.(xs)) +end + +function scan(c::Call{typeof(collect)}) + foreach(scan, c.args[1]) +end + +function back(::typeof(collect), Δ, xs) + foreach((x, Δ) -> @back(x, Δ), xs, Δ) +end diff --git a/test/optimise.jl b/test/optimise.jl index d57e4985..c896bb39 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -3,7 +3,7 @@ using Flux.Tracker @testset "Optimise" begin w = randn(10, 10) - @testset for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad] + @testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM] w′ = param(randn(10, 10)) loss(x) = Flux.mse(w*x, w′*x) opt = Opt([w′]) diff --git a/test/tracker.jl b/test/tracker.jl index 0f5b6189..66c08f62 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -1,5 +1,5 @@ using Flux.Tracker, Base.Test, NNlib -using Flux.Tracker: TrackedReal, gradcheck +using Flux.Tracker: TrackedReal, gradcheck, grad using NNlib: conv gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...) @@ -29,17 +29,97 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(x -> x', rand(5)) -@test gradtest(vcat, rand(5), rand(3)) -@test gradtest(vcat, rand(5), rand(3), rand(8)) -@test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2)) +function promotiontest(f, A, B, C) + r0 = f(A, B, C) + r1 = f(param(A), B, C) + r2 = f(A, param(B), C) + if all(ndims.((A,B,C)) .≤ 2) && f ∈ [hcat, vcat] + r3 = f(A, B, param(C)) + else + @test_throws MethodError f(A, B, param(C)) # until julia#20815 is resolved + r3 = r2 + end + r4 = f(param(A), param(B), param(C)) + + @test !isa(r0, TrackedArray) + @test all(isa.([r1,r2,r3,r4], TrackedArray)) + @test r1 == r2 == r3 == r4 + @test r0 == Flux.data(r4) +end + +@testset "concat" begin + cat1(x...) = cat(1, x...) + cat2(x...) = cat(2, x...) + + @testset for vcatf in [vcat, cat1] + @test gradtest(vcatf, rand(5), rand(3)) + @test gradtest(vcatf, rand(5), rand(3), rand(8)) + @test gradtest(vcatf, rand(5)', rand(5)') + @test gradtest(vcatf, rand(5,2), rand(3,2), rand(8,2)) + @test gradtest(vcatf, rand(5,2,3), rand(3,2,3), rand(8,2,3)) + @test gradtest(vcatf, rand(5), rand(3,1)) + @test gradtest(vcatf, rand(5)', rand(2,5)) + end + + @testset for hcatf in [hcat, cat2] + @test gradtest(hcatf, rand(5), rand(5)) + @test gradtest(hcatf, rand(5)', rand(5)') + @test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8)) + @test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3)) + @test gradtest(hcatf, rand(5), rand(5), rand(5,2)) + @test gradtest(hcatf, rand(5)', rand(1,3)) + @test gradtest(hcatf, rand(5), rand(5,2)) +end + + @testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)] + @test gradtest(catf, rand(5)) + @test gradtest(catf, rand(5)') + @test gradtest(catf, rand(2,5)) + @test gradtest(catf, rand(2,5,3)) + end + + @test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4)) + + @testset "cat($dim, ...)" for dim in 3:5 + catdim = (x...) -> cat(dim, x...) + @test gradtest(catdim, rand(5), rand(5), rand(5)) + @test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5)) + @test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3)) + end + + @test !isa(vcat(rand(2)), TrackedArray) + @test !isa(hcat(rand(2)), TrackedArray) + @test !isa(cat(1,rand(2)), TrackedArray) + + @test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1)) + + @testset "promotiontest" begin + @testset for fcat in [hcat, vcat, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)] + promotiontest(fcat, rand(2), rand(2), rand(2)) + promotiontest(fcat, rand(2)', rand(2)', rand(2)') + promotiontest(fcat, rand(2,2), rand(2,2), rand(2,2)) + promotiontest(fcat, rand(2,2,2), rand(2,2,2), rand(2,2,2)) + end + + promotiontest(vcat, rand(1,2), rand(2)', rand(2,2)) + promotiontest(hcat, rand(2,1), rand(2), rand(2,2)) + promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5)) + promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5)) + promotiontest((x...) -> cat(3, x...), rand(4,5,3), rand(4,5,1), rand(4,5,2)) + end +end + @test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6)) @test gradtest(x -> repmat(x, 5,5), rand(4,5)) @test gradtest(x -> repmat(x, 5), rand(4,5)) -@test gradtest(kron,rand(5), rand(3)) +@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5)) +@test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3)) + +@test gradtest(kron, rand(5), rand(3)) @test gradtest(kron, rand(5), rand(3), rand(8)) -@test gradtest(kron,rand(5,1), rand(3,1)) +@test gradtest(kron, rand(5,1), rand(3,1)) @test gradtest(kron, rand(5,1), rand(3,1), rand(8,1)) @test gradtest(kron, rand(5,2), rand(3,2), rand(8,2)) @@ -55,6 +135,26 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4)) end +@testset "maximum" begin + @test gradtest(maximum, rand(2, 3)) + + @test gradtest(x -> maximum(x, 1), rand(2, 3)) + @test gradtest(x -> maximum(x, 2), rand(2, 3)) + @test gradtest(x -> maximum(x, 3), rand(2, 3, 4)) + + @test gradtest(x -> maximum(x, [1, 2]), rand(2, 3, 4)) +end + +@testset "minimum" begin + @test gradtest(minimum, rand(2, 3)) + + @test gradtest(x -> minimum(x, 1), rand(2, 3)) + @test gradtest(x -> minimum(x, 2), rand(2, 3)) + @test gradtest(x -> minimum(x, 3), rand(2, 3, 4)) + + @test gradtest(x -> minimum(x, [1, 2]), rand(2, 3, 4)) +end + @test gradtest(x -> std(x), rand(5,5)) @test gradtest(x -> std(x, 1), rand(5,5)) @@ -123,4 +223,13 @@ b = param(rand()) Tracker.back!(b) @test Tracker.grad(b) == 1 +@testset "collect" begin + x, y = param(2), param(3) + xy = Tracker.collect([x, y]) + @test xy isa TrackedArray{Float64} + z = xy[1]*xy[2] + back!(z) + @test grad.((x,y)) == (3, 2) +end + end #testset