Avik Pal 2018-07-04 07:31:45 +05:30
commit c38d4edef7
14 changed files with 84 additions and 30 deletions

View File

@@ -30,7 +30,7 @@ Flux has powerful high-level features, and common architectures can be defined i
 ```julia
 model = Chain(
-  Dense(768, 128),
+  Dense(768, 128, σ),
   LSTM(128, 256)
   LSTM(256, 128)
   Dense(128, 10),

View File

@@ -4,6 +4,8 @@ Support for array operations on other hardware backends, like GPUs, is provided
 For example, we can use `CuArrays` (with the `cu` converter) to run our [basic example](models/basics.md) on an NVIDIA GPU.
+(Note that you need to build Julia 0.6 from source and have CUDA available to use CuArrays; please see the [CUDAnative.jl](https://github.com/JuliaGPU/CUDAnative.jl) instructions for more details.)
 ```julia
 using CuArrays
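As a concrete sketch of the `cu` workflow described above (not part of the diff; it assumes the Flux 0.6-era `mapleaves` API and a working CUDA setup):

```julia
using Flux, CuArrays

x = cu(rand(10))        # data lives on the GPU as a CuArray
d = Dense(10, 5, σ)
d = mapleaves(cu, d)    # convert the layer's parameters to CuArrays as well
d(x)                    # the forward pass now runs on the GPU
```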

View File

@@ -14,3 +14,5 @@ Pkg.test("Flux") # Check things installed correctly
 ```
 Start with the [basics](models/basics.md). The [model zoo](https://github.com/FluxML/model-zoo/) is also a good starting point for many common kinds of models.
+See [GPU support](gpu.md) for more details on installing and using Flux with GPUs.

View File

@@ -28,13 +28,15 @@ l = loss(x, y)
 back!(l)
 ```
-`loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then calculates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model.
+`loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then accumulates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model.
 ```julia
-W.grad
-# Update the parameter
-W.data .-= 0.1(W.grad)
+using Flux.Tracker: grad, update!
+
+Δ = grad(W)
+
+# Update the parameter and reset the gradient
+update!(W, -0.1Δ)
 loss(x, y) # ~ 2.5
 ```
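As a companion to this hunk, here is a minimal end-to-end sketch of the new `grad`/`update!` workflow; the parameter shapes, the loss, and the 0.1 step size are illustrative assumptions rather than part of the commit:

```julia
using Flux
using Flux.Tracker: back!, grad, update!

W = param(rand(2, 5))
b = param(rand(2))
predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)

x, y = rand(5), rand(2)
back!(loss(x, y))         # accumulate gradients into W and b
update!(W, -0.1grad(W))   # apply a gradient step and reset W's gradient
update!(b, -0.1grad(b))   # same for the bias
```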

View File

@@ -44,3 +44,19 @@ loss(x, y) = crossentropy(m(x), y) + sum(vecnorm, params(m))
 loss(rand(28^2), rand(10))
 ```
+
+One can also easily add per-layer regularisation via the `activations` function:
+
+```julia
+julia> c = Chain(Dense(10,5,σ),Dense(5,2),softmax)
+Chain(Dense(10, 5, NNlib.σ), Dense(5, 2), NNlib.softmax)
+
+julia> activations(c, rand(10))
+3-element Array{Any,1}:
+ param([0.71068, 0.831145, 0.751219, 0.227116, 0.553074])
+ param([0.0330606, -0.456104])
+ param([0.61991, 0.38009])
+
+julia> sum(vecnorm, ans)
+2.639678767773633 (tracked)
+```
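As a follow-up, a sketch (not from the commit) of how these per-layer activations could feed into a regularised loss; the layer sizes and the penalty are assumptions, and `activations` may need to be qualified as `Flux.activations`:

```julia
using Flux
using Flux: crossentropy

c = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)

# Penalise the norm of every intermediate activation, not just the parameters.
loss(x, y) = crossentropy(c(x), y) + sum(vecnorm, activations(c, x))

loss(rand(10), [0.5, 0.5])
```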

View File

@@ -17,16 +17,17 @@ back!(l)
 We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
 ```julia
-function update()
+using Flux.Tracker: grad, update!
+
+function sgd()
   η = 0.1 # Learning Rate
   for p in (W, b)
-    p.data .-= η .* p.grad # Apply the update
-    p.grad .= 0 # Clear the gradient
+    update!(p, -η * grad(p))
   end
 end
 ```
-If we call `update`, the parameters `W` and `b` will change and our loss should go down.
+If we call `sgd`, the parameters `W` and `b` will change and our loss should go down.
 There are two pieces here: one is that we need a list of trainable parameters for the model (`[W, b]` in this case), and the other is the update step. In this case the update is simply gradient descent (`x .-= η .* Δ`), but we might choose to do something more advanced, like adding momentum.
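For comparison, a sketch (not part of this diff) of the same step using Flux's built-in optimiser interface of this era, assuming the `SGD(params, η)` API and the `W`, `b`, `loss`, `x`, `y` defined earlier on this page:

```julia
using Flux
using Flux.Tracker: back!

# The hand-written `sgd` above is essentially what the built-in optimiser does for us.
opt = SGD([W, b], 0.1)   # a callable that updates W and b from their gradients
back!(loss(x, y))        # accumulate gradients
opt()                    # apply the update
```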

View File

@@ -38,6 +38,11 @@ function Base.show(io::IO, c::Chain)
   print(io, ")")
 end
 
+# Seem to need this for `accumulate`; try removing on 0.7
+Base.rcum_promote_type(op, ::Type, ::Type{Any}) = Any
+
+activations(c::Chain, x) = accumulate((x, m) -> m(x), x, c.layers)
+
 """
     Dense(in::Integer, out::Integer, σ = identity)
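To make the `accumulate` one-liner above concrete, here is an equivalent explicit loop; this is illustration only, and the name `activations_loop` is hypothetical:

```julia
# Feed `x` through the chain layer by layer, recording each intermediate output,
# which is what `accumulate((x, m) -> m(x), x, c.layers)` computes.
function activations_loop(c::Chain, x)
  outputs = Any[]
  for m in c.layers
    x = m(x)
    push!(outputs, x)
  end
  return outputs
end
```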

View File

@@ -1,5 +1,10 @@
 using NNlib: conv
 
+@generated sub2(::Type{Val{N}}) where N = :(Val{$(N-2)})
+
+expand(N, i::Tuple) = i
+expand(N, i::Integer) = ntuple(_ -> i, N)
+
 """
     Conv(size, in=>out)
     Conv(size, in=>out, relu)
@@ -21,14 +26,12 @@ struct Conv{N,F,A,V}
   dilation::NTuple{N,Int}
 end
 
-Conv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity;
-     stride = 1, pad = 0, dilation=1) where T =
-  Conv(σ, w, b, stride, pad, dilation)
+Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
+     stride = 1, pad = 0, dilation = 1) where {T,N} =
+  Conv(σ, w, b, expand.(sub2(Val{N}), (stride, pad, dilation))...)
 
 Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
-     stride::NTuple{N,Integer} = map(_->1,k),
-     pad::NTuple{N,Integer} = map(_->0,k),
-     dilation::NTuple{N,Integer} = map(_->1,k)) where N =
+     stride = 1, pad = 0, dilation = 1) where N =
   Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
        stride = stride, pad = pad, dilation = dilation)
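A short illustration (not from the diff) of what the new `sub2` and `expand` helpers do to the keyword arguments, assuming 4-D convolution weights, i.e. two spatial dimensions:

```julia
# For weights of size (3, 3, 1, 16), N = 4 and there are N - 2 = 2 spatial dims.
sub2(Val{4})            # Val{2}
expand(Val{2}, 1)       # (1, 1): a scalar stride/pad/dilation is repeated per spatial dim
expand(Val{2}, (2, 2))  # (2, 2): tuples pass through unchanged
```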

View File

@@ -15,9 +15,9 @@ function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight
 end
 
 """
-    binarycrossentropy(ŷ, y)
+    binarycrossentropy(ŷ, y; ϵ=eps(ŷ))
 
-Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`.
+Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerical stability.
 
     julia> binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0.])
     3-element Array{Float64,1}:
@@ -25,7 +25,7 @@ Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`.
     0.352317
     0.86167
 """
-binarycrossentropy(ŷ, y) = -y*log(ŷ) - (1 - y)*log(1 - ŷ)
+binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
 
 """
     logitbinarycrossentropy(logŷ, y)
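To see why the added ϵ matters, a tiny illustration (not from the commit); the exact numbers are only indicative:

```julia
using Flux: binarycrossentropy

# Without the ϵ term, a prediction of exactly 0 or 1 gives Inf (or NaN via 0 * Inf).
-1 * log(0.0)                # Inf
binarycrossentropy(0.0, 1)   # large but finite, thanks to the ϵ added inside the logs
```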

View File

@@ -37,8 +37,6 @@ function train!(loss, data, opt; cb = () -> ())
   opt = runall(opt)
   @progress for d in data
     l = loss(d...)
-    isinf(l) && error("Loss is Inf")
-    isnan(l) && error("Loss is NaN")
     @interrupts back!(l)
     opt()
     cb() == :stop && break

View File

@@ -10,6 +10,7 @@ istracked(x) = tracker(x) ≠ nothing
 isleaf(x) = !istracked(x) || isleaf(tracker(x))
 data(x) = istracked(x) ? data(tracker(x)) : x
 grad(x) = grad(tracker(x))
+grad(::Void) = nothing
 
 struct Call{F,As<:Tuple}
   func::F
@@ -46,11 +47,27 @@ isleaf(x::Tracked) = x.f == Call(nothing)
 data(x::Tracked) = x.data
 grad(x::Tracked) = x.grad
 
+function update!(x, Δ)
+  tracker(x).data += Δ
+  tracker(x).grad .= 0
+  return x
+end
+
 include("back.jl")
 include("scalar.jl")
 include("array.jl")
 include("numeric.jl")
 
+"""
+    hook(f, x) -> x
+
+Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
+`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
+the sign of the gradient applied to `x`.
+"""
+hook(f, x) = istracked(x) ? track(hook, f, x) : x
+back(::typeof(hook), Δ, f, x) = @back(x, f(Δ))
+
 param(x::Number) = TrackedReal(float(x))
 param(xs::AbstractArray) = TrackedArray(float.(xs))
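A small usage sketch for the new `hook` (not part of the diff); the concrete values are illustrative and assume scalar tracked values:

```julia
using Flux
using Flux.Tracker: hook, back!, grad

x = param(2.0)
y = hook(-, x) * 3   # the forward value is unchanged: 6.0
back!(y)
grad(x)              # -3.0 instead of 3.0, because the hook negated the incoming gradient
```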

View File

@@ -1,4 +1,4 @@
-function gradient(f, xs::AbstractArray...)
+function gradient(f, xs...)
   xs = param.(xs)
   back!(f(xs...))
   grad.(xs)
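With the loosened signature, `gradient` also accepts plain numbers; a quick sketch, illustrative rather than taken from the diff:

```julia
using Flux.Tracker: gradient

# Both arguments may now be scalars, not just arrays.
gradient((a, b) -> a * b, 2, 3)   # (3.0, 2.0)
```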

View File

@@ -8,7 +8,11 @@ tracker(x::TrackedReal) = x.tracker
 
 track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x)))
 
-back!(x::TrackedReal) = back!(x, 1)
+function back!(x::TrackedReal)
+  isinf(x) && error("Loss is Inf")
+  isnan(x) && error("Loss is NaN")
+  return back!(x, 1)
+end
 
 function Base.show(io::IO, x::TrackedReal)
   show(io, data(x))
@@ -19,15 +23,16 @@ Base.decompose(x::TrackedReal) = Base.decompose(data(x))
 
 Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x
 
-# This cuts derivatives, fix if needed.
-# Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T =
-#   TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data)))
-
 Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))
 
+Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
+  error("Not implemented: convert tracked $S to tracked $T")
+
 Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y)
 Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y)
 
+Base.eps(x::TrackedReal) = eps(data(x))
+
 for f in :[isinf, isnan, isfinite].args
   @eval Base.$f(x::TrackedReal) = Base.$f(data(x))
 end
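A small sketch (not from the diff) of why forwarding `eps` to the underlying data is handy, for example when losses like `binarycrossentropy` call `eps` on a tracked value; it assumes Float64 data:

```julia
using Flux

x = param(0.5)       # TrackedReal{Float64}
eps(x) == eps(0.5)   # true: eps now simply looks at the underlying Float64
```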

View File

@@ -1,7 +1,9 @@
 using Base.Test
 using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
   σ, binarycrossentropy, logitbinarycrossentropy
 
+const ϵ = 1e-7
+
 @testset "losses" begin
   # First, regression-style y's
   y = [1, 1, 0, 0]
@@ -40,10 +42,11 @@ using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
   logŷ, y = randn(3), rand(3)
 
   @testset "binarycrossentropy" begin
-    @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
+    @test binarycrossentropy.(σ.(logŷ), y; ϵ=0) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
+    @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 - y).*log.(1 - σ.(logŷ) .+ eps.(σ.(logŷ)))
   end
 
   @testset "logitbinarycrossentropy" begin
-    @test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y)
+    @test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y; ϵ=0)
  end
 end