Merge branch 'master' of https://github.com/FluxML/Flux.jl
commit c38d4edef7
@@ -30,7 +30,7 @@ Flux has powerful high-level features, and common architectures can be defined i
 ```julia
 model = Chain(
-  Dense(768, 128),
+  Dense(768, 128, σ),
   LSTM(128, 256),
   LSTM(256, 128),
   Dense(128, 10),
@@ -4,6 +4,8 @@ Support for array operations on other hardware backends, like GPUs, is provided

 For example, we can use `CuArrays` (with the `cu` converter) to run our [basic example](models/basics.md) on an NVIDIA GPU.

+(Note that you need to build Julia 0.6 from source and have CUDA available to use CuArrays – please see the [CUDAnative.jl](https://github.com/JuliaGPU/CUDAnative.jl) instructions for more details.)
+
 ```julia
 using CuArrays
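For context, a minimal sketch of what the `cu` converter mentioned above does (the `W` and `x` names are illustrative, not part of the docs):

```julia
using CuArrays

W = cu(rand(2, 5))   # copy an array into GPU memory
x = cu(rand(5))
y = W * x            # the matrix-vector product now runs on the GPU
```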
@@ -14,3 +14,5 @@ Pkg.test("Flux") # Check things installed correctly
 ```

 Start with the [basics](models/basics.md). The [model zoo](https://github.com/FluxML/model-zoo/) is also a good starting point for many common kinds of models.
+
+See [GPU support](gpu.md) for more details on installing and using Flux with GPUs.
@@ -28,13 +28,15 @@ l = loss(x, y)
 back!(l)
 ```

-`loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then calculates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model.
+`loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then accumulates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model.

 ```julia
-W.grad
+using Flux.Tracker: grad, update!
+
+Δ = grad(W)

-# Update the parameter
-W.data .-= 0.1(W.grad)
+# Update the parameter and reset the gradient
+update!(W, -0.1Δ)

 loss(x, y) # ~ 2.5
 ```
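To make the updated snippet self-contained, here is a hedged sketch of the surrounding setup the basics page assumes (dummy random data; the names follow the docs):

```julia
using Flux
using Flux.Tracker: back!, grad, update!

W = param(rand(2, 5))
b = param(rand(2))

predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)

x, y = rand(5), rand(2)   # dummy data for illustration
l = loss(x, y)            # a tracked number
back!(l)                  # accumulate gradients into W and b
Δ = grad(W)               # gradient of the loss with respect to W
update!(W, -0.1Δ)         # take a descent step and reset W's gradient
```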
@@ -44,3 +44,19 @@ loss(x, y) = crossentropy(m(x), y) + sum(vecnorm, params(m))

 loss(rand(28^2), rand(10))
 ```
+
+One can also easily add per-layer regularisation via the `activations` function:
+
+```julia
+julia> c = Chain(Dense(10,5,σ),Dense(5,2),softmax)
+Chain(Dense(10, 5, NNlib.σ), Dense(5, 2), NNlib.softmax)
+
+julia> activations(c, rand(10))
+3-element Array{Any,1}:
+ param([0.71068, 0.831145, 0.751219, 0.227116, 0.553074])
+ param([0.0330606, -0.456104])
+ param([0.61991, 0.38009])
+
+julia> sum(vecnorm, ans)
+2.639678767773633 (tracked)
+```
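As a follow-up to the REPL session above, a hedged sketch of how those per-layer activations might be folded into a loss (the `m` and `penalty` names are illustrative; it assumes `activations` is reachable as `Flux.activations`, and `vecnorm` is the Julia 0.6 spelling):

```julia
using Flux

m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)

# Penalise every intermediate activation, not just the parameters.
penalty(x) = sum(vecnorm, Flux.activations(m, x))
loss(x, y) = crossentropy(m(x), y) + penalty(x)

loss(rand(10), [0, 1])
```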
@@ -17,16 +17,17 @@ back!(l)

 We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:

 ```julia
-function update()
+using Flux.Tracker: grad, update!
+
+function sgd()
   η = 0.1 # Learning Rate
   for p in (W, b)
-    p.data .-= η .* p.grad # Apply the update
-    p.grad .= 0 # Clear the gradient
+    update!(p, -η * grad(p))
   end
 end
 ```

-If we call `update`, the parameters `W` and `b` will change and our loss should go down.
+If we call `sgd`, the parameters `W` and `b` will change and our loss should go down.

 There are two pieces here: one is that we need a list of trainable parameters for the model (`[W, b]` in this case), and the other is the update step. In this case the update is simply gradient descent (`x .-= η .* Δ`), but we might choose to do something more advanced, like adding momentum.
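Since the paragraph mentions momentum, here is a rough sketch of what such an update could look like in the same style (the `velocities` buffer, `ρ`, and `sgd_momentum` are illustrative, not part of the docs):

```julia
using Flux
using Flux.Tracker: grad, update!

W = param(rand(2, 5))   # the parameters from the basics example
b = param(rand(2))

η = 0.1   # learning rate
ρ = 0.9   # momentum coefficient

# One velocity buffer per parameter, initialised to zero.
velocities = [zeros(Flux.data(p)) for p in (W, b)]

function sgd_momentum()
  for (p, v) in zip((W, b), velocities)
    v .= ρ .* v .+ η .* grad(p)   # running average of recent gradients
    update!(p, -v)                # take the step; update! also clears p's gradient
  end
end
```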
@@ -38,6 +38,11 @@ function Base.show(io::IO, c::Chain)
   print(io, ")")
 end

+# Seem to need this for `accumulate`; try removing on 0.7
+Base.rcum_promote_type(op, ::Type, ::Type{Any}) = Any
+
+activations(c::Chain, x) = accumulate((x, m) -> m(x), x, c.layers)
+
 """
     Dense(in::Integer, out::Integer, σ = identity)
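For readers unfamiliar with the pattern the new `activations` method relies on, a small illustration of the three-argument `accumulate` (plain functions stand in for layers; this assumes the Julia 0.6 form of `accumulate` that the hunk targets):

```julia
# accumulate threads the running value through each function and keeps every
# intermediate result, which is exactly the list of per-layer outputs for a Chain.
fs = [x -> 2x, x -> x + 1, x -> x^2]
accumulate((x, f) -> f(x), 3, fs)   # => [6, 7, 49]
```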
@@ -1,5 +1,10 @@
 using NNlib: conv

+@generated sub2(::Type{Val{N}}) where N = :(Val{$(N-2)})
+
+expand(N, i::Tuple) = i
+expand(N, i::Integer) = ntuple(_ -> i, N)
+
 """
     Conv(size, in=>out)
     Conv(size, in=>out, relu)
@@ -21,14 +26,12 @@ struct Conv{N,F,A,V}
   dilation::NTuple{N,Int}
 end

-Conv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity;
-     stride = 1, pad = 0, dilation=1) where T =
-  Conv(σ, w, b, stride, pad, dilation)
+Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
+     stride = 1, pad = 0, dilation = 1) where {T,N} =
+  Conv(σ, w, b, expand.(sub2(Val{N}), (stride, pad, dilation))...)

 Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
-     stride::NTuple{N,Integer} = map(_->1,k),
-     pad::NTuple{N,Integer} = map(_->0,k),
-     dilation::NTuple{N,Integer} = map(_->1,k)) where N =
+     stride = 1, pad = 0, dilation = 1) where N =
   Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
        stride = stride, pad = pad, dilation = dilation)
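A brief illustration of what the new helpers do, with the definitions from the first hunk reproduced so the snippet is self-contained (Julia 0.6 syntax; values chosen for illustration):

```julia
@generated sub2(::Type{Val{N}}) where N = :(Val{$(N-2)})

expand(N, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N)

expand(Val{2}, 1)        # => (1, 1): a scalar stride applies to both spatial dims
expand(Val{2}, (2, 3))   # => (2, 3): tuples pass through unchanged

# For a 4-D weight array (width, height, channels in, channels out),
# sub2(Val{4}) == Val{2}, which drives the broadcast used in the new constructor:
expand.(sub2(Val{4}), (1, 0, 1))   # => ((1, 1), (0, 0), (1, 1))
```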
@@ -15,9 +15,9 @@ function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight
 end

 """
-    binarycrossentropy(ŷ, y)
+    binarycrossentropy(ŷ, y; ϵ=eps(ŷ))

-Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`.
+Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerical stability.

     julia> binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0.])
     3-element Array{Float64,1}:
@@ -25,7 +25,7 @@ Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`.
     0.352317
     0.86167
 """
-binarycrossentropy(ŷ, y) = -y*log(ŷ) - (1 - y)*log(1 - ŷ)
+binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)

 """
     logitbinarycrossentropy(logŷ, y)
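A quick sanity check of why the ϵ term matters, with the new definition reproduced so the comparison is self-contained (values are illustrative):

```julia
binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)

binarycrossentropy(1.0, 0.0)         # ≈ 36.04: large but finite, thanks to ϵ
binarycrossentropy(1.0, 0.0, ϵ=0)    # Inf: what the old definition produced
```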
@@ -37,8 +37,6 @@ function train!(loss, data, opt; cb = () -> ())
   opt = runall(opt)
   @progress for d in data
     l = loss(d...)
-    isinf(l) && error("Loss is Inf")
-    isnan(l) && error("Loss is NaN")
     @interrupts back!(l)
     opt()
     cb() == :stop && break
@@ -10,6 +10,7 @@ istracked(x) = tracker(x) ≠ nothing
 isleaf(x) = !istracked(x) || isleaf(tracker(x))
 data(x) = istracked(x) ? data(tracker(x)) : x
 grad(x) = grad(tracker(x))
+grad(::Void) = nothing

 struct Call{F,As<:Tuple}
   func::F
@@ -46,11 +47,27 @@ isleaf(x::Tracked) = x.f == Call(nothing)
 data(x::Tracked) = x.data
 grad(x::Tracked) = x.grad

+function update!(x, Δ)
+  tracker(x).data += Δ
+  tracker(x).grad .= 0
+  return x
+end
+
 include("back.jl")
 include("scalar.jl")
 include("array.jl")
 include("numeric.jl")

+"""
+    hook(f, x) -> x′
+
+Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
+`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
+the sign of the gradient applied to `x`.
+"""
+hook(f, x) = istracked(x) ? track(hook, f, x) : x
+back(::typeof(hook), Δ, f, x) = @back(x, f(Δ))
+
 param(x::Number) = TrackedReal(float(x))
 param(xs::AbstractArray) = TrackedArray(float.(xs))
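A hedged sketch of how the new `hook` might be used from user code (names are illustrative; it assumes the Tracker API shown elsewhere in this diff):

```julia
using Flux
using Flux.Tracker: back!, grad, hook

x = param([1.0, 2.0, 3.0])
y = hook(Δ -> 2Δ, x)   # forward value unchanged; the gradient will be doubled
back!(sum(y))
grad(x)                # => [2.0, 2.0, 2.0] rather than [1.0, 1.0, 1.0]
```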
@@ -1,4 +1,4 @@
-function gradient(f, xs::AbstractArray...)
+function gradient(f, xs...)
   xs = param.(xs)
   back!(f(xs...))
   grad.(xs)
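The relaxed signature means `gradient` now accepts plain numbers as well as arrays; a small hedged example (assuming it is reachable as `Flux.Tracker.gradient`):

```julia
using Flux.Tracker: gradient

gradient((a, b) -> a * b, 2, 3)   # => (3.0, 2.0)
```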
@@ -8,7 +8,11 @@ tracker(x::TrackedReal) = x.tracker

 track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x)))

-back!(x::TrackedReal) = back!(x, 1)
+function back!(x::TrackedReal)
+  isinf(x) && error("Loss is Inf")
+  isnan(x) && error("Loss is NaN")
+  return back!(x, 1)
+end

 function Base.show(io::IO, x::TrackedReal)
   show(io, data(x))
@@ -19,15 +23,16 @@ Base.decompose(x::TrackedReal) = Base.decompose(data(x))

 Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x

-# This cuts derivatives, fix if needed.
-# Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T =
-#   TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data)))
-
 Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))

+Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
+  error("Not implemented: convert tracked $S to tracked $T")
+
 Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y)
 Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y)

 Base.eps(x::TrackedReal) = eps(data(x))

+for f in :[isinf, isnan, isfinite].args
+  @eval Base.$f(x::TrackedReal) = Base.$f(data(x))
+end
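The quoted-vector idiom in the new loop can be puzzling at first glance; a small illustration (the loop body itself is unchanged from the hunk above):

```julia
# :[isinf, isnan, isfinite] quotes the array literal; its args field is just the
# symbols, so the loop defines one forwarding method per name, e.g.
#   Base.isinf(x::TrackedReal) = Base.isinf(data(x))
:[isinf, isnan, isfinite].args   # => the symbols :isinf, :isnan, :isfinite
```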
@@ -1,7 +1,9 @@
 using Base.Test
 using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
   σ, binarycrossentropy, logitbinarycrossentropy

+const ϵ = 1e-7
+
 @testset "losses" begin
   # First, regression-style y's
   y = [1, 1, 0, 0]
@@ -40,10 +42,11 @@ using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,

   logŷ, y = randn(3), rand(3)
   @testset "binarycrossentropy" begin
-    @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
+    @test binarycrossentropy.(σ.(logŷ), y; ϵ=0) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
+    @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 - y).*log.(1 - σ.(logŷ) .+ eps.(σ.(logŷ)))
   end

   @testset "logitbinarycrossentropy" begin
-    @test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y)
+    @test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y; ϵ=0)
   end
 end