From f6097d58d6416a842509c915c5d07da9e28dc4d6 Mon Sep 17 00:00:00 2001 From: tejank10 Date: Sun, 15 Apr 2018 12:15:41 +0530 Subject: [PATCH 01/22] Scalar pad/stride for Conv constructor --- src/layers/conv.jl | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 994648c2..e8054829 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -24,11 +24,16 @@ Conv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity; stride = 1, pad = 0) where T = Conv(σ, w, b, stride, pad) +#Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn, +# stride::NTuple{N,Integer} = map(_->1,k), +# pad::NTuple{N,Integer} = map(_->0,k)) where N = +# Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ, +# stride = stride, pad = pad) + Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn, - stride::NTuple{N,Integer} = map(_->1,k), - pad::NTuple{N,Integer} = map(_->0,k)) where N = + stride::Integer = 1, pad::Integer = 0) where N = Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ, - stride = stride, pad = pad) + stride = map(_->stride,k), pad = map(_->pad,k)) Flux.treelike(Conv) From b080f5c82e080d2aebff2f5f00c5d2612da2e4d9 Mon Sep 17 00:00:00 2001 From: tejank10 Date: Sun, 15 Apr 2018 20:32:40 +0530 Subject: [PATCH 02/22] Scalar pad and stride --- src/layers/conv.jl | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index e8054829..addef9fc 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1,5 +1,8 @@ using NNlib: conv +expand(::Type{Val{N}}, i::Integer) where N = ntuple(_ -> i, Val{N}) +expand(::Type{Val{N}}, i::NTuple{N, Integer}) where N = i + """ Conv(size, in=>out) Conv(size, in=>out, relu) @@ -24,16 +27,10 @@ Conv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity; stride = 1, pad = 0) where T = Conv(σ, w, b, stride, pad) -#Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn, -# stride::NTuple{N,Integer} = map(_->1,k), -# pad::NTuple{N,Integer} = map(_->0,k)) where N = -# Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ, -# stride = stride, pad = pad) - Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn, - stride::Integer = 1, pad::Integer = 0) where N = + stride = 1, pad = 0) where N = Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ, - stride = map(_->stride,k), pad = map(_->pad,k)) + stride = expand(Val{N}, stride), pad = expand(Val{N}, pad)) Flux.treelike(Conv) From 2f5473d4351be3302fdf31a4bdb81810fe05b4f6 Mon Sep 17 00:00:00 2001 From: tejank10 Date: Mon, 16 Apr 2018 00:59:11 +0530 Subject: [PATCH 03/22] added expand in conv constructor --- src/layers/conv.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index addef9fc..9c2ba9c3 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -24,8 +24,9 @@ struct Conv{N,F,A,V} end Conv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity; - stride = 1, pad = 0) where T = - Conv(σ, w, b, stride, pad) + stride = 1, pad = 0) where T = + Conv(σ, w, b, expand(Val{ndims(w) - 2}, stride), + pad = expand(Val{ndims(w) - 2}, pad)) Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn, stride = 1, pad = 0) where N = From 2ef25775c678acf2b7ab92dccc2fcb171ba222f7 Mon Sep 17 00:00:00 2001 From: tejank10 Date: Mon, 16 Apr 2018 01:18:26 +0530 Subject: [PATCH 04/22] removed extra 
expand and fixed bug --- src/layers/conv.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 9c2ba9c3..1ef38a21 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -26,12 +26,12 @@ end Conv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity; stride = 1, pad = 0) where T = Conv(σ, w, b, expand(Val{ndims(w) - 2}, stride), - pad = expand(Val{ndims(w) - 2}, pad)) + expand(Val{ndims(w) - 2}, pad)) Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn, stride = 1, pad = 0) where N = - Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ, - stride = expand(Val{N}, stride), pad = expand(Val{N}, pad)) + Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ; + stride = stride, pad = pad) Flux.treelike(Conv) From 9fdbe843eff2eb091ccc25c4dfcb74ee6f258eb8 Mon Sep 17 00:00:00 2001 From: "staticfloat@gmail.com" Date: Mon, 7 May 2018 15:30:44 -0700 Subject: [PATCH 05/22] Check for `Inf` and `NaN` within `back!(::TrackedReal)` This is often checked for within user code, no reason to do that, let's do it for them within `back!(::TrackedReal)` --- src/optimise/train.jl | 2 -- src/tracker/scalar.jl | 10 +++++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 401a1c51..8ad8573e 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -37,8 +37,6 @@ function train!(loss, data, opt; cb = () -> ()) opt = runall(opt) @progress for d in data l = loss(d...) - isinf(l) && error("Loss is Inf") - isnan(l) && error("Loss is NaN") @interrupts back!(l) opt() cb() == :stop && break diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 632046cd..17c35513 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -8,7 +8,15 @@ tracker(x::TrackedReal) = x.tracker track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x))) -back!(x::TrackedReal) = back!(x, 1) +function back!(x::TrackedReal) + if isinf(x) + error("Loss is Inf") + end + if isnan(x) + error("Loss is NaN") + end + return back!(x, 1) +end function Base.show(io::IO, x::TrackedReal) show(io, data(x)) From ac1448f677d127fcf68867635543f374630bbcb5 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 13 Jun 2018 11:13:48 +0100 Subject: [PATCH 06/22] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 0baa74d4..10d611f1 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ Flux has powerful high-level features, and common architectures can be defined i ```julia model = Chain( - Dense(768, 128), + Dense(768, 128, σ), LSTM(128, 256) LSTM(256, 128) Dense(128, 10), From aea1e73cdede14a579670a59859d7f2654824e8b Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 21 Jun 2018 13:12:42 +0100 Subject: [PATCH 07/22] scalar gradients --- src/tracker/numeric.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tracker/numeric.jl b/src/tracker/numeric.jl index 272f9ba4..755e1f7d 100644 --- a/src/tracker/numeric.jl +++ b/src/tracker/numeric.jl @@ -1,4 +1,4 @@ -function gradient(f, xs::AbstractArray...) +function gradient(f, xs...) 
xs = param.(xs) back!(f(xs...)) grad.(xs) From 7e3cf45ee4e674bc8d1f8d087538477478b5d65f Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Mon, 25 Jun 2018 11:36:52 +0100 Subject: [PATCH 08/22] better error --- src/tracker/scalar.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 5deaf66c..8d0aa29e 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -19,12 +19,11 @@ Base.decompose(x::TrackedReal) = Base.decompose(data(x)) Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x -# This cuts derivatives, fix if needed. -# Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T = -# TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data))) - Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x)) +Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} = + error("Not implemented: convert tracked $S to tracked $T") + Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y) Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y) From 7726a5b605832b0fc35e26332a0f0f83a5d5f210 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 26 Jun 2018 14:05:07 +0100 Subject: [PATCH 09/22] inferrable --- src/layers/conv.jl | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 7548fc96..38310aad 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1,7 +1,9 @@ using NNlib: conv -expand(::Type{Val{N}}, i::Integer) where N = ntuple(_ -> i, Val{N}) -expand(::Type{Val{N}}, i::NTuple{N, Integer}) where N = i +@generated sub2(::Type{Val{N}}) where N = :(Val{$(N-2)}) + +expand(N, i::Tuple) = i +expand(N, i::Integer) = ntuple(_ -> i, N) """ Conv(size, in=>out) @@ -24,9 +26,9 @@ struct Conv{N,F,A,V} dilation::NTuple{N,Int} end -Conv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity; - stride = 1, pad = 0, dilation = 1) where T = - Conv(σ, w, b, expand.(Val{ndims(w)-2}, (stride, pad, dilation))...) +Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity; + stride = 1, pad = 0, dilation = 1) where {T,N} = + Conv(σ, w, b, expand.(sub2(Val{N}), (stride, pad, dilation))...) Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn, stride = 1, pad = 0, dilation = 1) where N = From aa8f79f10ce1c4a9daccab0c91a5859c838c5a30 Mon Sep 17 00:00:00 2001 From: Kade Date: Thu, 19 Apr 2018 07:48:30 -0500 Subject: [PATCH 10/22] Mention CUDAnative.jl's install instructions --- README.md | 2 ++ docs/src/gpu.md | 2 ++ docs/src/index.md | 2 ++ 3 files changed, 6 insertions(+) diff --git a/README.md b/README.md index f8e301ed..cbd3633e 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,8 @@ Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, and provides lightweight abstractions on top of Julia's native GPU and AD support. Flux makes the easy things easy while remaining fully hackable. +You need to build Julia 0.6 from source and have CUDA available to use Flux with GPUs – please see the [CUDAnative.jl](https://github.com/JuliaGPU/CUDAnative.jl) instructions for more details. 
+ ```julia julia> Pkg.add("Flux") ``` diff --git a/docs/src/gpu.md b/docs/src/gpu.md index 253904ad..8fd36f98 100644 --- a/docs/src/gpu.md +++ b/docs/src/gpu.md @@ -1,5 +1,7 @@ # GPU Support +You need to build Julia 0.6 from source and have CUDA available to use Flux with GPUs – please see the [CUDAnative.jl](https://github.com/JuliaGPU/CUDAnative.jl) instructions for more details. + Support for array operations on other hardware backends, like GPUs, is provided by external packages like [CuArrays](https://github.com/JuliaGPU/CuArrays.jl). Flux is agnostic to array types, so we simply need to move model weights and data to the GPU and Flux will handle it. For example, we can use `CuArrays` (with the `cu` converter) to run our [basic example](models/basics.md) on an NVIDIA GPU. diff --git a/docs/src/index.md b/docs/src/index.md index 86c9c3dc..afeb2075 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -14,3 +14,5 @@ Pkg.test("Flux") # Check things installed correctly ``` Start with the [basics](models/basics.md). The [model zoo](https://github.com/FluxML/model-zoo/) is also a good starting point for many common kinds of models. + +See [GPU support](gpu.md) for more details on installing and using Flux with GPUs. From 1490d87d8387a078a29a336cb37fd7269240179e Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 26 Jun 2018 14:25:24 +0100 Subject: [PATCH 11/22] tweaks --- README.md | 2 -- docs/src/gpu.md | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index cbd3633e..f8e301ed 100644 --- a/README.md +++ b/README.md @@ -6,8 +6,6 @@ Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, and provides lightweight abstractions on top of Julia's native GPU and AD support. Flux makes the easy things easy while remaining fully hackable. -You need to build Julia 0.6 from source and have CUDA available to use Flux with GPUs – please see the [CUDAnative.jl](https://github.com/JuliaGPU/CUDAnative.jl) instructions for more details. - ```julia julia> Pkg.add("Flux") ``` diff --git a/docs/src/gpu.md b/docs/src/gpu.md index 8fd36f98..6be2d7b0 100644 --- a/docs/src/gpu.md +++ b/docs/src/gpu.md @@ -1,11 +1,11 @@ # GPU Support -You need to build Julia 0.6 from source and have CUDA available to use Flux with GPUs – please see the [CUDAnative.jl](https://github.com/JuliaGPU/CUDAnative.jl) instructions for more details. - Support for array operations on other hardware backends, like GPUs, is provided by external packages like [CuArrays](https://github.com/JuliaGPU/CuArrays.jl). Flux is agnostic to array types, so we simply need to move model weights and data to the GPU and Flux will handle it. For example, we can use `CuArrays` (with the `cu` converter) to run our [basic example](models/basics.md) on an NVIDIA GPU. +(Note that you need to build Julia 0.6 from source and have CUDA available to use CuArrays – please see the [CUDAnative.jl](https://github.com/JuliaGPU/CUDAnative.jl) instructions for more details.) 
+ ```julia using CuArrays From 0a04e3ba61b0cb951b41184971f30d6c41dbf510 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 26 Jun 2018 14:30:46 +0100 Subject: [PATCH 12/22] Chain `activations` --- src/layers/basic.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index ad374643..cf89df41 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -38,6 +38,11 @@ function Base.show(io::IO, c::Chain) print(io, ")") end +# Seem to need this for `accumulate`; try removing on 0.7 +Base.rcum_promote_type(op, ::Type, ::Type{Any}) = Any + +activations(c::Chain, x) = accumulate((x, m) -> m(x), x, c.layers) + """ Dense(in::Integer, out::Integer, σ = identity) From d6a75e1289488945ade1cfa8867717fe1ed19557 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 26 Jun 2018 14:35:03 +0100 Subject: [PATCH 13/22] add `activations` docs --- docs/src/models/regularisation.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/docs/src/models/regularisation.md b/docs/src/models/regularisation.md index 70d06348..cd53544f 100644 --- a/docs/src/models/regularisation.md +++ b/docs/src/models/regularisation.md @@ -44,3 +44,19 @@ loss(x, y) = crossentropy(m(x), y) + sum(vecnorm, params(m)) loss(rand(28^2), rand(10)) ``` + +One can also easily add per-layer regularisation via the `activations` function: + +```julia +julia> c = Chain(Dense(10,5,σ),Dense(5,2),softmax) +Chain(Dense(10, 5, NNlib.σ), Dense(5, 2), NNlib.softmax) + +julia> activations(c, rand(10)) +3-element Array{Any,1}: + param([0.71068, 0.831145, 0.751219, 0.227116, 0.553074]) + param([0.0330606, -0.456104]) + param([0.61991, 0.38009]) + +julia> sum(vecnorm, ans) +2.639678767773633 (tracked) +``` From 836e3872b69d70e8c1c073617da25790f1ceca14 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 26 Jun 2018 15:09:21 +0100 Subject: [PATCH 14/22] style --- src/tracker/scalar.jl | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 5b6cfb57..f94f1647 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -9,12 +9,8 @@ tracker(x::TrackedReal) = x.tracker track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x))) function back!(x::TrackedReal) - if isinf(x) - error("Loss is Inf") - end - if isnan(x) - error("Loss is NaN") - end + isinf(x) && error("Loss is Inf") + isnan(x) && error("Loss is NaN") return back!(x, 1) end From 88c16e62dd2b764bc3a8a44a536090895c10269b Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 26 Jun 2018 14:57:43 +0100 Subject: [PATCH 15/22] fixes #284 --- src/tracker/Tracker.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 8d4a8ca7..1761d5fd 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -10,6 +10,7 @@ istracked(x) = tracker(x) ≠ nothing isleaf(x) = !istracked(x) || isleaf(tracker(x)) data(x) = istracked(x) ? 
data(tracker(x)) : x grad(x) = grad(tracker(x)) +grad(::Void) = nothing struct Call{F,As<:Tuple} func::F From bed6d2311e4c9778c8057249ccce9a636f334cc0 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 26 Jun 2018 16:07:58 +0100 Subject: [PATCH 16/22] clearer docs --- docs/src/models/basics.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/src/models/basics.md b/docs/src/models/basics.md index 02225279..d99a5426 100644 --- a/docs/src/models/basics.md +++ b/docs/src/models/basics.md @@ -28,13 +28,15 @@ l = loss(x, y) back!(l) ``` -`loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then calculates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model. +`loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then accumulates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model. ```julia W.grad # Update the parameter W.data .-= 0.1(W.grad) +# Reset the gradient +W.grad .= 0 loss(x, y) # ~ 2.5 ``` From e08fd7a6d2323153a304d211f9b97c6c9c78074f Mon Sep 17 00:00:00 2001 From: Matthew Kelley Date: Tue, 26 Jun 2018 11:43:16 -0600 Subject: [PATCH 17/22] Added epsilon term to binarycrossentropy --- src/layers/stateless.jl | 6 +++--- test/layers/stateless.jl | 11 +++++++---- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index ccd4fe4c..6fe28d30 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -15,9 +15,9 @@ function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight end """ - binarycrossentropy(ŷ, y) + binarycrossentropy(ŷ, y; ϵ) -Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`. +Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerical stability. julia> binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0.]) 3-element Array{Float64,1}: @@ -25,7 +25,7 @@ Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`. 
0.352317 0.86167 """ -binarycrossentropy(ŷ, y) = -y*log(ŷ) - (1 - y)*log(1 - ŷ) +binarycrossentropy(ŷ, y; ϵ=1e-7) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ) """ logitbinarycrossentropy(logŷ, y) diff --git a/test/layers/stateless.jl b/test/layers/stateless.jl index ecfa7014..7c641261 100644 --- a/test/layers/stateless.jl +++ b/test/layers/stateless.jl @@ -1,7 +1,9 @@ using Base.Test -using Flux: onehotbatch, mse, crossentropy, logitcrossentropy, +using Flux: onehotbatch, mse, crossentropy, logitcrossentropy, σ, binarycrossentropy, logitbinarycrossentropy +const ϵ = 1e-7 + @testset "losses" begin # First, regression-style y's y = [1, 1, 0, 0] @@ -40,10 +42,11 @@ using Flux: onehotbatch, mse, crossentropy, logitcrossentropy, logŷ, y = randn(3), rand(3) @testset "binarycrossentropy" begin - @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ)) + @test binarycrossentropy.(σ.(logŷ), y; ϵ=0) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ)) + @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ) + 1e-7) - (1 - y).*log.(1 - σ.(logŷ) + 1e-7) end - + @testset "logitbinarycrossentropy" begin - @test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y) + @test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y; ϵ=0) end end From ed032cdb1ed565509dd60b580ed0b68b0a2b422f Mon Sep 17 00:00:00 2001 From: Matthew Kelley Date: Tue, 26 Jun 2018 12:29:06 -0600 Subject: [PATCH 18/22] =?UTF-8?q?Change=20epsilon=20value=20to=20eps(y?= =?UTF-8?q?=CC=82)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/layers/stateless.jl | 4 ++-- test/layers/stateless.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 6fe28d30..ba80e8a6 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -15,7 +15,7 @@ function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight end """ - binarycrossentropy(ŷ, y; ϵ) + binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerical stability. @@ -25,7 +25,7 @@ Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. 
The ϵ term provides numerica 0.352317 0.86167 """ -binarycrossentropy(ŷ, y; ϵ=1e-7) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ) +binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ) """ logitbinarycrossentropy(logŷ, y) diff --git a/test/layers/stateless.jl b/test/layers/stateless.jl index 7c641261..91d530ca 100644 --- a/test/layers/stateless.jl +++ b/test/layers/stateless.jl @@ -43,7 +43,7 @@ const ϵ = 1e-7 logŷ, y = randn(3), rand(3) @testset "binarycrossentropy" begin @test binarycrossentropy.(σ.(logŷ), y; ϵ=0) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ)) - @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ) + 1e-7) - (1 - y).*log.(1 - σ.(logŷ) + 1e-7) + @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ) .+ eps.(ŷ)) - (1 - y).*log.(1 - σ.(logŷ) .+ eps.(ŷ)) end @testset "logitbinarycrossentropy" begin From 0e95be33269f0582aafd0de716b60793cdb8de35 Mon Sep 17 00:00:00 2001 From: Matthew Kelley Date: Tue, 26 Jun 2018 14:48:51 -0600 Subject: [PATCH 19/22] =?UTF-8?q?Call=20Flux.Tracker.data()=20on=20y=CC=82?= =?UTF-8?q?=20for=20bce?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/layers/stateless.jl | 3 ++- test/layers/stateless.jl | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index ba80e8a6..7d3eb8d0 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -25,7 +25,8 @@ Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerica 0.352317 0.86167 """ -binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ) +# binarycrossentropy(ŷ, y; ϵ=eps(Flux.Tracker.data(ŷ))) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ) +binarycrossentropy(ŷ, y; ϵ=eps(Flux.Tracker.data(ŷ))) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ) """ logitbinarycrossentropy(logŷ, y) diff --git a/test/layers/stateless.jl b/test/layers/stateless.jl index 91d530ca..31a67aa7 100644 --- a/test/layers/stateless.jl +++ b/test/layers/stateless.jl @@ -43,7 +43,7 @@ const ϵ = 1e-7 logŷ, y = randn(3), rand(3) @testset "binarycrossentropy" begin @test binarycrossentropy.(σ.(logŷ), y; ϵ=0) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ)) - @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ) .+ eps.(ŷ)) - (1 - y).*log.(1 - σ.(logŷ) .+ eps.(ŷ)) + @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 - y).*log.(1 - σ.(logŷ) .+ eps.(σ.(logŷ))) end @testset "logitbinarycrossentropy" begin From 864d72eef5e0488274e4fd5e8acd4f253d112112 Mon Sep 17 00:00:00 2001 From: Matthew Kelley Date: Tue, 26 Jun 2018 23:55:43 -0600 Subject: [PATCH 20/22] Overload Base.eps() for TrackedReal --- src/layers/stateless.jl | 3 +-- src/tracker/scalar.jl | 2 ++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 7d3eb8d0..ba80e8a6 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -25,8 +25,7 @@ Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. 
The ϵ term provides numerica 0.352317 0.86167 """ -# binarycrossentropy(ŷ, y; ϵ=eps(Flux.Tracker.data(ŷ))) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ) -binarycrossentropy(ŷ, y; ϵ=eps(Flux.Tracker.data(ŷ))) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ) +binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ) """ logitbinarycrossentropy(logŷ, y) diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index f94f1647..773943c0 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -31,6 +31,8 @@ Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} = Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y) Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y) +Base.eps(x::TrackedReal) = eps(data(x)) + for f in :[isinf, isnan, isfinite].args @eval Base.$f(x::TrackedReal) = Base.$f(data(x)) end From 5d8b63dc65fe6bf72109e10217e51c72dbce75f7 Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Fri, 29 Jun 2018 13:53:50 +0100 Subject: [PATCH 21/22] avoid implementation details in docs --- docs/src/models/basics.md | 10 +++++----- docs/src/training/optimisers.md | 9 +++++---- src/tracker/Tracker.jl | 6 ++++++ 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/docs/src/models/basics.md b/docs/src/models/basics.md index d99a5426..96efc7b8 100644 --- a/docs/src/models/basics.md +++ b/docs/src/models/basics.md @@ -31,12 +31,12 @@ back!(l) `loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then accumulates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model. ```julia -W.grad +using Flux.Tracker: grad, update! -# Update the parameter -W.data .-= 0.1(W.grad) -# Reset the gradient -W.grad .= 0 +Δ = grad(W) + +# Update the parameter and reset the gradient +update!(W, -0.1Δ) loss(x, y) # ~ 2.5 ``` diff --git a/docs/src/training/optimisers.md b/docs/src/training/optimisers.md index 56f511e4..ac58f6d0 100644 --- a/docs/src/training/optimisers.md +++ b/docs/src/training/optimisers.md @@ -17,16 +17,17 @@ back!(l) We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that: ```julia -function update() +using Flux.Tracker: grad, update! + +function sgd() η = 0.1 # Learning Rate for p in (W, b) - p.data .-= η .* p.grad # Apply the update - p.grad .= 0 # Clear the gradient + update!(p, -η * grad(p)) end end ``` -If we call `update`, the parameters `W` and `b` will change and our loss should go down. +If we call `sgd`, the parameters `W` and `b` will change and our loss should go down. There are two pieces here: one is that we need a list of trainable parameters for the model (`[W, b]` in this case), and the other is the update step. In this case the update is simply gradient descent (`x .-= η .* Δ`), but we might choose to do something more advanced, like adding momentum. 
diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 1761d5fd..7b1fdce7 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -47,6 +47,12 @@ isleaf(x::Tracked) = x.f == Call(nothing) data(x::Tracked) = x.data grad(x::Tracked) = x.grad +function update!(x, Δ) + tracker(x).data += Δ + tracker(x).grad .= 0 + return x +end + include("back.jl") include("scalar.jl") include("array.jl") From ce88273880730990ef2e236b775b2080eca12f4a Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 2 Jul 2018 13:17:46 +0100 Subject: [PATCH 22/22] gradient hook --- src/tracker/Tracker.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 7b1fdce7..1296d179 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -58,6 +58,16 @@ include("scalar.jl") include("array.jl") include("numeric.jl") +""" + hook(f, x) -> x′ + +Hook into gradient backpropagation. `x` is unmodified, but when backpropagating +`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse +the sign of the gradient applied to `x`. +""" +hook(f, x) = istracked(x) ? track(hook, f, x) : x +back(::typeof(hook), Δ, f, x) = @back(x, f(Δ)) + param(x::Number) = TrackedReal(float(x)) param(xs::AbstractArray) = TrackedArray(float.(xs))