Merge branch 'master' of github.com:jarvist/Flux.jl into HEAD
commit 344a750770

@@ -30,7 +30,7 @@ Flux has powerful high-level features, and common architectures can be defined i
 ```julia
 model = Chain(
-  Dense(768, 128),
+  Dense(768, 128, σ),
   LSTM(128, 256)
   LSTM(256, 128)
   Dense(128, 10),

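For orientation, a minimal sketch of this snippet in use (not part of the diff; the layer sizes are the ones shown above, with commas added so the chain is valid Julia, and the input is made up):

```julia
using Flux

model = Chain(
  Dense(768, 128, σ),
  LSTM(128, 256),
  LSTM(256, 128),
  Dense(128, 10))

x = rand(768)   # a single 768-dimensional input
model(x)        # 10-element tracked output
```
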
@@ -4,6 +4,8 @@ Support for array operations on other hardware backends, like GPUs, is provided
 For example, we can use `CuArrays` (with the `cu` converter) to run our [basic example](models/basics.md) on an NVIDIA GPU.
 
+(Note that you need to build Julia 0.6 from source and have CUDA available to use CuArrays – please see the [CUDAnative.jl](https://github.com/JuliaGPU/CUDAnative.jl) instructions for more details.)
+
 ```julia
 using CuArrays
 

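As a quick illustration of the `cu` converter mentioned above (a sketch, assuming a working CuArrays/CUDAnative setup):

```julia
using CuArrays

W = cu(rand(2, 5))  # copy a matrix to the GPU
x = cu(rand(5))
W * x               # the multiply now runs on the GPU
```
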
@@ -14,3 +14,5 @@ Pkg.test("Flux") # Check things installed correctly
 ```
 
 Start with the [basics](models/basics.md). The [model zoo](https://github.com/FluxML/model-zoo/) is also a good starting point for many common kinds of models.
+
+See [GPU support](gpu.md) for more details on installing and using Flux with GPUs.

@@ -28,13 +28,15 @@ l = loss(x, y)
 back!(l)
 ```
 
-`loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then calculates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model.
+`loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then accumulates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model.
 
 ```julia
-W.grad
+using Flux.Tracker: grad, update!
 
-# Update the parameter
-W.data .-= 0.1(W.grad)
+Δ = grad(W)
+
+# Update the parameter and reset the gradient
+update!(W, -0.1Δ)
 
 loss(x, y) # ~ 2.5
 ```

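Putting this hunk together, a hedged end-to-end sketch of the documented workflow (the `W`, `b` and `loss` definitions are assumed from earlier in the basics page; the sizes are illustrative):

```julia
using Flux
using Flux.Tracker: back!, grad, update!

W = param(rand(2, 5))
b = param(rand(2))

predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)

x, y = rand(5), rand(2)
l = loss(x, y)      # a tracked scalar
back!(l)            # fills in the gradients of W and b

Δ = grad(W)
update!(W, -0.1Δ)   # take a descent step and reset W's gradient
```
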
@@ -7,6 +7,7 @@ add the result to the overall loss.
 For example, say we have a simple regression.
 
 ```julia
+using Flux: crossentropy
 m = Dense(10, 5)
 loss(x, y) = crossentropy(softmax(m(x)), y)
 ```

@@ -43,3 +44,19 @@ loss(x, y) = crossentropy(m(x), y) + sum(vecnorm, params(m))
 
 loss(rand(28^2), rand(10))
 ```
+
+One can also easily add per-layer regularisation via the `activations` function:
+
+```julia
+julia> c = Chain(Dense(10,5,σ),Dense(5,2),softmax)
+Chain(Dense(10, 5, NNlib.σ), Dense(5, 2), NNlib.softmax)
+
+julia> activations(c, rand(10))
+3-element Array{Any,1}:
+ param([0.71068, 0.831145, 0.751219, 0.227116, 0.553074])
+ param([0.0330606, -0.456104])
+ param([0.61991, 0.38009])
+
+julia> sum(vecnorm, ans)
+2.639678767773633 (tracked)
+```

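A hedged sketch of folding this per-layer penalty into a loss, combining the two snippets above (the 0.01 weight and the layer sizes are made up; `activations` is referenced through the `Flux` module in case it is not exported):

```julia
using Flux
using Flux: crossentropy

c = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)

# Sum of the norms of every intermediate layer output.
penalty(x) = sum(vecnorm, Flux.activations(c, x))
loss(x, y) = crossentropy(c(x), y) + 0.01penalty(x)

loss(rand(10), [0, 1])
```
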
@@ -17,16 +17,17 @@ back!(l)
 We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
 
 ```julia
-function update()
+using Flux.Tracker: grad, update!
+
+function sgd()
   η = 0.1 # Learning Rate
   for p in (W, b)
-    p.data .-= η .* p.grad # Apply the update
-    p.grad .= 0 # Clear the gradient
+    update!(p, -η * grad(p))
   end
 end
 ```
 
-If we call `update`, the parameters `W` and `b` will change and our loss should go down.
+If we call `sgd`, the parameters `W` and `b` will change and our loss should go down.
 
 There are two pieces here: one is that we need a list of trainable parameters for the model (`[W, b]` in this case), and the other is the update step. In this case the update is simply gradient descent (`x .-= η .* Δ`), but we might choose to do something more advanced, like adding momentum.
 

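For comparison, the same step written against the built-in optimiser interface (a sketch; `W`, `b`, `x`, `y` and `loss` are assumed from the surrounding page):

```julia
using Flux

opt = SGD([W, b], 0.1)   # built-in gradient descent with η = 0.1
back!(loss(x, y))
opt()                    # apply the update to W and b
```
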
@@ -23,7 +23,7 @@ include("optimise/Optimise.jl")
 using .Optimise
 using .Optimise: @epochs
 export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
-  RMSProp, ADAGrad, ADADelta, AMSGrad
+  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
 
 include("utils.jl")
 include("onehot.jl")

@@ -38,6 +38,11 @@ function Base.show(io::IO, c::Chain)
   print(io, ")")
 end
 
+# Seem to need this for `accumulate`; try removing on 0.7
+Base.rcum_promote_type(op, ::Type, ::Type{Any}) = Any
+
+activations(c::Chain, x) = accumulate((x, m) -> m(x), x, c.layers)
+
 """
     Dense(in::Integer, out::Integer, σ = identity)
 

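The `accumulate`-based one-liner above is compact; roughly (a sketch, not the implementation) it does the following:

```julia
# Collect the output of every layer in the chain, feeding each layer
# the previous layer's output.
function activations_sketch(c::Chain, x)
  ys = []
  for m in c.layers
    x = m(x)
    push!(ys, x)
  end
  return ys
end
```
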
@@ -1,5 +1,10 @@
 using NNlib: conv
 
+@generated sub2(::Type{Val{N}}) where N = :(Val{$(N-2)})
+
+expand(N, i::Tuple) = i
+expand(N, i::Integer) = ntuple(_ -> i, N)
+
 """
     Conv(size, in=>out)
     Conv(size, in=>out, relu)

@@ -10,7 +15,7 @@ Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
 Data should be stored in WHCN order. In other words, a 100×100 RGB image would
 be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
 
-Takes the keyword arguments `pad` and `stride`.
+Takes the keyword arguments `pad`, `stride` and `dilation`.
 """
 struct Conv{N,F,A,V}
   σ::F

@@ -18,17 +23,17 @@ struct Conv{N,F,A,V}
   bias::V
   stride::NTuple{N,Int}
   pad::NTuple{N,Int}
+  dilation::NTuple{N,Int}
 end
 
-Conv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity;
-     stride = 1, pad = 0) where T =
-  Conv(σ, w, b, stride, pad)
+Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
+     stride = 1, pad = 0, dilation = 1) where {T,N} =
+  Conv(σ, w, b, expand.(sub2(Val{N}), (stride, pad, dilation))...)
 
 Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
-     stride::NTuple{N,Integer} = map(_->1,k),
-     pad::NTuple{N,Integer} = map(_->0,k)) where N =
+     stride = 1, pad = 0, dilation = 1) where N =
   Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
-       stride = stride, pad = pad)
+       stride = stride, pad = pad, dilation = dilation)
 
 Flux.treelike(Conv)
 
@@ -36,7 +41,7 @@ function (c::Conv)(x)
   # TODO: breaks gpu broadcast :(
   # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
   σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
-  σ.(conv(x, c.weight, stride = c.stride, pad = c.pad) .+ b)
+  σ.(conv(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
 end
 
 function Base.show(io::IO, l::Conv)

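A hedged sketch of the new keyword in use (the layer sizes and input are made up; this assumes an NNlib version whose `conv` accepts `dilation`):

```julia
using Flux

# 3×3 filters, 3 input channels, 16 output channels, dilation 2.
c = Conv((3, 3), 3=>16, relu, dilation = 2, pad = 2)

x = rand(32, 32, 3, 1)   # WHCN: one 32×32 RGB image
size(c(x))               # pad = 2 offsets the dilated 3×3 kernel, keeping 32×32
```
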
@@ -31,15 +31,14 @@ function Dropout(p)
   Dropout{typeof(p)}(p, true)
 end
 
+_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
+
 function (a::Dropout)(x)
   a.active || return x
   y = similar(x)
   rand!(y)
-  q = 1 - a.p
-  @inbounds for i=1:length(y)
-    y[i] = y[i] > a.p ? 1 / q : 0
-  end
-  return y .* x
+  y .= _dropout_kernel.(y, a.p, 1 - a.p)
+  return x .* y
 end
 
 _testmode!(a::Dropout, test) = (a.active = !test)

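For reference, a standalone sketch of what the broadcasted kernel computes: inverted dropout, where surviving entries are scaled by `1/(1 - p)` so the expected activation is unchanged.

```julia
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)

p = 0.5
y = rand(5)                           # uniform noise, one draw per activation
mask = _dropout_kernel.(y, p, 1 - p)  # each entry is 0 or 1/(1 - p)
x = ones(5)
x .* mask                             # dropped-out activations
```
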
@@ -15,9 +15,9 @@ function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight
 end
 
 """
-    binarycrossentropy(ŷ, y)
+    binarycrossentropy(ŷ, y; ϵ=eps(ŷ))
 
-Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`.
+Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerical stability.
 
     julia> binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0.])
     3-element Array{Float64,1}:
@@ -25,7 +25,7 @@ Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`.
      0.352317
      0.86167
 """
-binarycrossentropy(ŷ, y) = -y*log(ŷ) - (1 - y)*log(1 - ŷ)
+binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
 
 """
     logitbinarycrossentropy(logŷ, y)

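A small hedged example of the new keyword (the values are arbitrary):

```julia
using Flux: binarycrossentropy, σ

ŷ, y = σ(0.8619), 1.0
binarycrossentropy(ŷ, y)         # uses ϵ = eps(ŷ) by default
binarycrossentropy(ŷ, y; ϵ = 0)  # exact formula, matching the old behaviour
```
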
@@ -1,7 +1,8 @@
 module Optimise
 
 export train!,
-  SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
+  SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
+  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
 
 struct Param{T}
   x::T

@@ -99,3 +99,12 @@ tuning.
 """
 AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
   optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
+
+"""
+    NADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
+
+[NADAM](https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ) optimiser. Parameters other
+than learning rate don't need tuning.
+"""
+NADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
+  optimiser(ps, p->nadam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))

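A hedged usage sketch for the newly exported optimiser (the model, data and loss here are made up):

```julia
using Flux

m = Dense(10, 2)
loss(x, y) = Flux.mse(m(x), y)
opt = NADAM(params(m))            # default η = 0.001

data = [(rand(10), rand(2))]
Flux.train!(loss, data, opt)
```
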
@@ -35,7 +35,7 @@ function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
   acc = zeros(p.x)
   function ()
     @. acc = ρ * acc + (1 - ρ) * p.Δ^2
-    @. p.Δ *= η / (√acc + ϵ)
+    @. p.Δ *= η / √(acc + ϵ)
   end
 end
 

@@ -43,7 +43,7 @@ function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
   acc = zeros(p.x) .+ ϵ
   function ()
     @. acc += p.Δ^2
-    @. p.Δ *= η / √acc
+    @. p.Δ *= η / √(acc + ϵ)
   end
 end
 

@@ -64,7 +64,7 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ
   function ()
     @. mt = β1 * mt + (1 - β1) * p.Δ
     @. vt = β2 * vt + (1 - β2) * p.Δ^2
-    @. p.Δ = mt / (1 - β1p) / (√(vt / (1 - β2p)) + ϵ) * η
+    @. p.Δ = mt / (1 - β1p) / √(vt / (1 - β2p) + ϵ) * η
     β1p *= β1
     β2p *= β2
   end

@@ -94,6 +94,19 @@ function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999,
   end
 end
 
+function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
+  mt = zeros(p.x)
+  vt = zeros(p.x)
+  β1p, β2p = β1, β2
+  function ()
+    @. mt = β1 * mt + (1 - β1) * p.Δ
+    @. vt = β2 * vt + (1 - β2) * p.Δ^2
+    @. p.Δ = (β1 * mt / (1 - β1 * β1p) + (1 - β1) * p.Δ / (1 - β1p)) / √(vt * β2 / (1 - β2p) + ϵ) * η
+    β1p *= β1
+    β2p *= β2
+  end
+end
+
 clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)
 
 function expdecay(p::Param, γ::Real)

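For readers checking the closure against the paper, the broadcasted update it implements can be written out as follows, with `g` the current gradient `p.Δ`, `t` the number of calls so far, and `ϵ` kept inside the square root exactly as in the code (`descent(p, 1)` later subtracts the resulting `p.Δ` from the parameter):

```latex
m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g, \qquad
v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g^2

\Delta \leftarrow \eta\,
  \frac{\beta_1 m_t / (1-\beta_1^{\,t+1}) \;+\; (1-\beta_1)\, g / (1-\beta_1^{\,t})}
       {\sqrt{\,v_t\, \beta_2 / (1-\beta_2^{\,t}) + \epsilon\,}}
```
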
@@ -37,8 +37,6 @@ function train!(loss, data, opt; cb = () -> ())
   opt = runall(opt)
   @progress for d in data
     l = loss(d...)
-    isinf(l) && error("Loss is Inf")
-    isnan(l) && error("Loss is NaN")
     @interrupts back!(l)
     opt()
     cb() == :stop && break

@@ -10,6 +10,7 @@ istracked(x) = tracker(x) ≠ nothing
 isleaf(x) = !istracked(x) || isleaf(tracker(x))
 data(x) = istracked(x) ? data(tracker(x)) : x
 grad(x) = grad(tracker(x))
+grad(::Void) = nothing
 
 struct Call{F,As<:Tuple}
   func::F

@@ -46,11 +47,27 @@ isleaf(x::Tracked) = x.f == Call(nothing)
 data(x::Tracked) = x.data
 grad(x::Tracked) = x.grad
 
+function update!(x, Δ)
+  tracker(x).data += Δ
+  tracker(x).grad .= 0
+  return x
+end
+
 include("back.jl")
 include("scalar.jl")
 include("array.jl")
 include("numeric.jl")
 
+"""
+    hook(f, x) -> x′
+
+Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
+`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
+the sign of the gradient applied to `x`.
+"""
+hook(f, x) = istracked(x) ? track(hook, f, x) : x
+back(::typeof(hook), Δ, f, x) = @back(x, f(Δ))
+
 param(x::Number) = TrackedReal(float(x))
 param(xs::AbstractArray) = TrackedArray(float.(xs))
 

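A rough usage sketch of the new `hook`, mirroring the docstring's `hook(-, x)` example (scalar case; the exact call sequence depends on Tracker internals not shown in this diff):

```julia
using Flux
using Flux.Tracker: hook, back!, grad

x = param(5.0)
y = hook(-, x)   # reverse the sign of the gradient reaching x
back!(2y)        # gradient of 2y w.r.t. y is 2
grad(x)          # -2.0, since the hook negated it
```
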
|
@ -93,6 +93,26 @@ function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
|
|||||||
back(xs, Δ′)
|
back(xs, Δ′)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
_repeat(A, inner, outer) = Base.repeat(A; inner=inner, outer=outer)
|
||||||
|
Base.repeat(A::TrackedArray; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A))) = track(_repeat, A, inner, outer)
|
||||||
|
|
||||||
|
function back(::typeof(_repeat), Δ, xs::TrackedArray, inner, outer)
|
||||||
|
Δ′ = similar(xs.data)
|
||||||
|
Δ′ .= 0
|
||||||
|
S = size(xs.data)
|
||||||
|
|
||||||
|
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
|
||||||
|
for (dest_idx, val) in enumerate(IndexCartesian(), Δ)
|
||||||
|
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
|
||||||
|
# wrap around based on original size S.
|
||||||
|
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
|
||||||
|
Δ′[src_idx...] += val
|
||||||
|
end
|
||||||
|
back(xs, Δ′)
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
for f in [:vcat, :hcat]
|
for f in [:vcat, :hcat]
|
||||||
@eval begin
|
@eval begin
|
||||||
# This section is a bit of a hack since julia doesn't have a standardised
|
# This section is a bit of a hack since julia doesn't have a standardised
|
||||||
@@ -314,18 +334,18 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
 back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs)))
 
 # TODO: can store kwargs efficiently in namedtuples
-_conv(x, w, stride, pad) = conv(x, w, stride = stride, pad = pad)
+_conv(x, w, stride, pad, dilation) = conv(x, w, stride = stride, pad = pad, dilation = dilation)
 
-conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N =
-  track(_conv, x, w, stride, pad)
-conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N =
-  track(_conv, x, w, stride, pad)
-conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0) where N =
-  track(_conv, x, w, stride, pad)
+conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
+  track(_conv, x, w, stride, pad, dilation)
+conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
+  track(_conv, x, w, stride, pad, dilation)
+conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
+  track(_conv, x, w, stride, pad, dilation)
 
-function back(::typeof(_conv), Δ, x, w, stride, pad)
-  @back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad))
-  @back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad))
+function back(::typeof(_conv), Δ, x, w, stride, pad, dilation)
+  @back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation))
+  @back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation))
 end
 
 _maxpool(x, k, pad, stride) = maxpool(x, k; pad = pad, stride = stride)

|
@ -1,4 +1,4 @@
|
|||||||
function gradient(f, xs::AbstractArray...)
|
function gradient(f, xs...)
|
||||||
xs = param.(xs)
|
xs = param.(xs)
|
||||||
back!(f(xs...))
|
back!(f(xs...))
|
||||||
grad.(xs)
|
grad.(xs)
|
||||||
|
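The relaxed signature means `gradient` now also accepts plain numbers, not just arrays (a small sketch):

```julia
using Flux.Tracker: gradient

gradient((a, b) -> a*b, 2, 3)   # returns the gradients (3.0, 2.0)
```
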
@@ -8,7 +8,11 @@ tracker(x::TrackedReal) = x.tracker
 
 track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x)))
 
-back!(x::TrackedReal) = back!(x, 1)
+function back!(x::TrackedReal)
+  isinf(x) && error("Loss is Inf")
+  isnan(x) && error("Loss is NaN")
+  return back!(x, 1)
+end
 
 function Base.show(io::IO, x::TrackedReal)
   show(io, data(x))

@@ -19,14 +23,16 @@ Base.decompose(x::TrackedReal) = Base.decompose(data(x))
 
 Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x
 
-Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T =
-  TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data)))
-
 Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))
 
+Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
+  error("Not implemented: convert tracked $S to tracked $T")
+
 Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y)
 Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y)
 
+Base.eps(x::TrackedReal) = eps(data(x))
+
 for f in :[isinf, isnan, isfinite].args
   @eval Base.$f(x::TrackedReal) = Base.$f(data(x))
 end

@@ -91,3 +97,18 @@ Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
 
 back(::typeof(getindex), Δ, t, i) =
   back(t, ntuple(j -> i == j ? Δ : 0, length(t)))
+
+# Array collection
+
+function collect(xs)
+  xs = Base.collect(xs)
+  track(Call(collect, xs), data.(xs))
+end
+
+function scan(c::Call{typeof(collect)})
+  foreach(scan, c.args[1])
+end
+
+function back(::typeof(collect), Δ, xs)
+  foreach((x, Δ) -> @back(x, Δ), xs, Δ)
+end

@@ -2,6 +2,8 @@ using Base.Test
 using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
   σ, binarycrossentropy, logitbinarycrossentropy
 
+const ϵ = 1e-7
+
 @testset "losses" begin
   # First, regression-style y's
   y = [1, 1, 0, 0]

@@ -40,10 +42,11 @@ using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
 
   logŷ, y = randn(3), rand(3)
   @testset "binarycrossentropy" begin
-    @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
+    @test binarycrossentropy.(σ.(logŷ), y; ϵ=0) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
+    @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 - y).*log.(1 - σ.(logŷ) .+ eps.(σ.(logŷ)))
   end
 
   @testset "logitbinarycrossentropy" begin
-    @test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y)
+    @test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y; ϵ=0)
   end
 end

|
@ -3,7 +3,7 @@ using Flux.Tracker
|
|||||||
|
|
||||||
@testset "Optimise" begin
|
@testset "Optimise" begin
|
||||||
w = randn(10, 10)
|
w = randn(10, 10)
|
||||||
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
|
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
|
||||||
w′ = param(randn(10, 10))
|
w′ = param(randn(10, 10))
|
||||||
loss(x) = Flux.mse(w*x, w′*x)
|
loss(x) = Flux.mse(w*x, w′*x)
|
||||||
opt = Opt([w′])
|
opt = Opt([w′])
|
||||||
|
@@ -1,5 +1,5 @@
 using Flux.Tracker, Base.Test, NNlib
-using Flux.Tracker: TrackedReal, gradcheck
+using Flux.Tracker: TrackedReal, gradcheck, grad
 using NNlib: conv
 
 gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)

@@ -114,6 +114,9 @@ end
 @test gradtest(x -> repmat(x, 5,5), rand(4,5))
 @test gradtest(x -> repmat(x, 5), rand(4,5))
 
+@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
+@test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))
+
 @test gradtest(kron, rand(5), rand(3))
 @test gradtest(kron, rand(5), rand(3), rand(8))
 @test gradtest(kron, rand(5,1), rand(3,1))

@@ -220,4 +223,13 @@ b = param(rand())
 Tracker.back!(b)
 @test Tracker.grad(b) == 1
 
+@testset "collect" begin
+  x, y = param(2), param(3)
+  xy = Tracker.collect([x, y])
+  @test xy isa TrackedArray{Float64}
+  z = xy[1]*xy[2]
+  back!(z)
+  @test grad.((x,y)) == (3, 2)
+end
+
 end #testset