diff --git a/README.md b/README.md index 0baa74d4..10d611f1 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ Flux has powerful high-level features, and common architectures can be defined i ```julia model = Chain( - Dense(768, 128), + Dense(768, 128, σ), LSTM(128, 256) LSTM(256, 128) Dense(128, 10), diff --git a/docs/src/gpu.md b/docs/src/gpu.md index 253904ad..6be2d7b0 100644 --- a/docs/src/gpu.md +++ b/docs/src/gpu.md @@ -4,6 +4,8 @@ Support for array operations on other hardware backends, like GPUs, is provided For example, we can use `CuArrays` (with the `cu` converter) to run our [basic example](models/basics.md) on an NVIDIA GPU. +(Note that you need to build Julia 0.6 from source and have CUDA available to use CuArrays – please see the [CUDAnative.jl](https://github.com/JuliaGPU/CUDAnative.jl) instructions for more details.) + ```julia using CuArrays diff --git a/docs/src/index.md b/docs/src/index.md index 86c9c3dc..afeb2075 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -14,3 +14,5 @@ Pkg.test("Flux") # Check things installed correctly ``` Start with the [basics](models/basics.md). The [model zoo](https://github.com/FluxML/model-zoo/) is also a good starting point for many common kinds of models. + +See [GPU support](gpu.md) for more details on installing and using Flux with GPUs. diff --git a/docs/src/models/basics.md b/docs/src/models/basics.md index 02225279..96efc7b8 100644 --- a/docs/src/models/basics.md +++ b/docs/src/models/basics.md @@ -28,13 +28,15 @@ l = loss(x, y) back!(l) ``` -`loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then calculates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model. +`loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then accumulates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model. ```julia -W.grad +using Flux.Tracker: grad, update! -# Update the parameter -W.data .-= 0.1(W.grad) +Δ = grad(W) + +# Update the parameter and reset the gradient +update!(W, -0.1Δ) loss(x, y) # ~ 2.5 ``` diff --git a/docs/src/models/regularisation.md b/docs/src/models/regularisation.md index 70d06348..cd53544f 100644 --- a/docs/src/models/regularisation.md +++ b/docs/src/models/regularisation.md @@ -44,3 +44,19 @@ loss(x, y) = crossentropy(m(x), y) + sum(vecnorm, params(m)) loss(rand(28^2), rand(10)) ``` + +One can also easily add per-layer regularisation via the `activations` function: + +```julia +julia> c = Chain(Dense(10,5,σ),Dense(5,2),softmax) +Chain(Dense(10, 5, NNlib.σ), Dense(5, 2), NNlib.softmax) + +julia> activations(c, rand(10)) +3-element Array{Any,1}: + param([0.71068, 0.831145, 0.751219, 0.227116, 0.553074]) + param([0.0330606, -0.456104]) + param([0.61991, 0.38009]) + +julia> sum(vecnorm, ans) +2.639678767773633 (tracked) +``` diff --git a/docs/src/training/optimisers.md b/docs/src/training/optimisers.md index 56f511e4..ac58f6d0 100644 --- a/docs/src/training/optimisers.md +++ b/docs/src/training/optimisers.md @@ -17,16 +17,17 @@ back!(l) We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that: ```julia -function update() +using Flux.Tracker: grad, update! 
+ +function sgd() η = 0.1 # Learning Rate for p in (W, b) - p.data .-= η .* p.grad # Apply the update - p.grad .= 0 # Clear the gradient + update!(p, -η * grad(p)) end end ``` -If we call `update`, the parameters `W` and `b` will change and our loss should go down. +If we call `sgd`, the parameters `W` and `b` will change and our loss should go down. There are two pieces here: one is that we need a list of trainable parameters for the model (`[W, b]` in this case), and the other is the update step. In this case the update is simply gradient descent (`x .-= η .* Δ`), but we might choose to do something more advanced, like adding momentum. diff --git a/src/layers/basic.jl b/src/layers/basic.jl index ad374643..cf89df41 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -38,6 +38,11 @@ function Base.show(io::IO, c::Chain) print(io, ")") end +# Seem to need this for `accumulate`; try removing on 0.7 +Base.rcum_promote_type(op, ::Type, ::Type{Any}) = Any + +activations(c::Chain, x) = accumulate((x, m) -> m(x), x, c.layers) + """ Dense(in::Integer, out::Integer, σ = identity) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index c61676aa..38310aad 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1,5 +1,10 @@ using NNlib: conv +@generated sub2(::Type{Val{N}}) where N = :(Val{$(N-2)}) + +expand(N, i::Tuple) = i +expand(N, i::Integer) = ntuple(_ -> i, N) + """ Conv(size, in=>out) Conv(size, in=>out, relu) @@ -21,14 +26,12 @@ struct Conv{N,F,A,V} dilation::NTuple{N,Int} end -Conv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity; - stride = 1, pad = 0, dilation=1) where T = - Conv(σ, w, b, stride, pad, dilation) +Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity; + stride = 1, pad = 0, dilation = 1) where {T,N} = + Conv(σ, w, b, expand.(sub2(Val{N}), (stride, pad, dilation))...) Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn, - stride::NTuple{N,Integer} = map(_->1,k), - pad::NTuple{N,Integer} = map(_->0,k), - dilation::NTuple{N,Integer} = map(_->1,k)) where N = + stride = 1, pad = 0, dilation = 1) where N = Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ, stride = stride, pad = pad, dilation = dilation) diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index ccd4fe4c..ba80e8a6 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -15,9 +15,9 @@ function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight end """ - binarycrossentropy(ŷ, y) + binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) -Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`. +Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerical stability. julia> binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0.]) 3-element Array{Float64,1}: @@ -25,7 +25,7 @@ Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`. 0.352317 0.86167 """ -binarycrossentropy(ŷ, y) = -y*log(ŷ) - (1 - y)*log(1 - ŷ) +binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ) """ logitbinarycrossentropy(logŷ, y) diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 401a1c51..8ad8573e 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -37,8 +37,6 @@ function train!(loss, data, opt; cb = () -> ()) opt = runall(opt) @progress for d in data l = loss(d...) 
- isinf(l) && error("Loss is Inf") - isnan(l) && error("Loss is NaN") @interrupts back!(l) opt() cb() == :stop && break diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 8d4a8ca7..1296d179 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -10,6 +10,7 @@ istracked(x) = tracker(x) ≠ nothing isleaf(x) = !istracked(x) || isleaf(tracker(x)) data(x) = istracked(x) ? data(tracker(x)) : x grad(x) = grad(tracker(x)) +grad(::Void) = nothing struct Call{F,As<:Tuple} func::F @@ -46,11 +47,27 @@ isleaf(x::Tracked) = x.f == Call(nothing) data(x::Tracked) = x.data grad(x::Tracked) = x.grad +function update!(x, Δ) + tracker(x).data += Δ + tracker(x).grad .= 0 + return x +end + include("back.jl") include("scalar.jl") include("array.jl") include("numeric.jl") +""" + hook(f, x) -> x′ + +Hook into gradient backpropagation. `x` is unmodified, but when backpropagating +`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse +the sign of the gradient applied to `x`. +""" +hook(f, x) = istracked(x) ? track(hook, f, x) : x +back(::typeof(hook), Δ, f, x) = @back(x, f(Δ)) + param(x::Number) = TrackedReal(float(x)) param(xs::AbstractArray) = TrackedArray(float.(xs)) diff --git a/src/tracker/numeric.jl b/src/tracker/numeric.jl index 272f9ba4..755e1f7d 100644 --- a/src/tracker/numeric.jl +++ b/src/tracker/numeric.jl @@ -1,4 +1,4 @@ -function gradient(f, xs::AbstractArray...) +function gradient(f, xs...) xs = param.(xs) back!(f(xs...)) grad.(xs) diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 5deaf66c..773943c0 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -8,7 +8,11 @@ tracker(x::TrackedReal) = x.tracker track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x))) -back!(x::TrackedReal) = back!(x, 1) +function back!(x::TrackedReal) + isinf(x) && error("Loss is Inf") + isnan(x) && error("Loss is NaN") + return back!(x, 1) +end function Base.show(io::IO, x::TrackedReal) show(io, data(x)) @@ -19,15 +23,16 @@ Base.decompose(x::TrackedReal) = Base.decompose(data(x)) Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x -# This cuts derivatives, fix if needed. 
-# Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T = -# TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data))) - Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x)) +Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} = + error("Not implemented: convert tracked $S to tracked $T") + Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y) Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y) +Base.eps(x::TrackedReal) = eps(data(x)) + for f in :[isinf, isnan, isfinite].args @eval Base.$f(x::TrackedReal) = Base.$f(data(x)) end diff --git a/test/layers/stateless.jl b/test/layers/stateless.jl index ecfa7014..31a67aa7 100644 --- a/test/layers/stateless.jl +++ b/test/layers/stateless.jl @@ -1,7 +1,9 @@ using Base.Test -using Flux: onehotbatch, mse, crossentropy, logitcrossentropy, +using Flux: onehotbatch, mse, crossentropy, logitcrossentropy, σ, binarycrossentropy, logitbinarycrossentropy +const ϵ = 1e-7 + @testset "losses" begin # First, regression-style y's y = [1, 1, 0, 0] @@ -40,10 +42,11 @@ using Flux: onehotbatch, mse, crossentropy, logitcrossentropy, logŷ, y = randn(3), rand(3) @testset "binarycrossentropy" begin - @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ)) + @test binarycrossentropy.(σ.(logŷ), y; ϵ=0) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ)) + @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 - y).*log.(1 - σ.(logŷ) .+ eps.(σ.(logŷ))) end - + @testset "logitbinarycrossentropy" begin - @test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y) + @test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y; ϵ=0) end end
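
Putting the `Tracker.grad`/`Tracker.update!` changes together, the training step in `docs/src/models/basics.md` and `docs/src/training/optimisers.md` no longer touches `p.data`/`p.grad` directly. Below is a minimal sketch of that workflow; the model, `predict`, `loss` and the random data are made up to mirror the basics example, not taken from the patch itself.

```julia
using Flux
using Flux.Tracker: back!, grad, param, update!

W = param(rand(2, 5))
b = param(rand(2))

predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)

x, y = rand(5), rand(2)

back!(loss(x, y))       # accumulate gradients into W and b

Δ = grad(W)
update!(W, -0.1Δ)       # shift W by -0.1Δ and reset its gradient to zero
update!(b, -0.1grad(b))

loss(x, y)              # should now be lower than before the update
```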
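The new `hook` function in `src/tracker/Tracker.jl` only comes with a docstring, so here is a rough sketch of the `hook(-, x)` example it mentions. The expected gradient values are my reading of `back(::typeof(hook), ...)`, not output copied from a test.

```julia
using Flux
using Flux.Tracker: back!, grad, hook, param

x = param([1.0, 2.0, 3.0])
y = hook(-, x)    # forward value of y is the same as x

back!(sum(y))
grad(x)           # expected ≈ [-1.0, -1.0, -1.0] rather than all ones,
                  # since the hook negates the incoming gradient
```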
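The `ϵ` keyword added to `binarycrossentropy` guards against `log(0)` when a prediction saturates at exactly 0 or 1. Roughly, with illustrative inputs:

```julia
using Flux: binarycrossentropy

binarycrossentropy(0.0, 1.0)         # finite, since the default ϵ = eps(ŷ) keeps the log argument nonzero
binarycrossentropy(0.0, 1.0; ϵ = 0)  # Inf, the old behaviour
```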