Merge branch 'master' of github.com:jarvist/Flux.jl into HEAD

Jarvist Moore Frost 2018-07-03 11:15:43 +01:00
commit 344a750770
22 changed files with 180 additions and 53 deletions


@@ -30,7 +30,7 @@ Flux has powerful high-level features, and common architectures can be defined i
 ```julia
 model = Chain(
-  Dense(768, 128),
+  Dense(768, 128, σ),
   LSTM(128, 256)
   LSTM(256, 128)
   Dense(128, 10),


@@ -4,6 +4,8 @@ Support for array operations on other hardware backends, like GPUs, is provided
 For example, we can use `CuArrays` (with the `cu` converter) to run our [basic example](models/basics.md) on an NVIDIA GPU.
 
+(Note that you need to build Julia 0.6 from source and have CUDA available to use CuArrays; please see the [CUDAnative.jl](https://github.com/JuliaGPU/CUDAnative.jl) instructions for more details.)
+
 ```julia
 using CuArrays


@@ -14,3 +14,5 @@ Pkg.test("Flux") # Check things installed correctly
 ```
 
 Start with the [basics](models/basics.md). The [model zoo](https://github.com/FluxML/model-zoo/) is also a good starting point for many common kinds of models.
+
+See [GPU support](gpu.md) for more details on installing and using Flux with GPUs.


@@ -28,13 +28,15 @@ l = loss(x, y)
 back!(l)
 ```
 
-`loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then calculates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model.
+`loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then accumulates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model.
 
 ```julia
-W.grad
-
-# Update the parameter
-W.data .-= 0.1(W.grad)
+using Flux.Tracker: grad, update!
+
+Δ = grad(W)
+
+# Update the parameter and reset the gradient
+update!(W, -0.1Δ)
 
 loss(x, y) # ~ 2.5
 ```
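
The hunk above moves the docs from poking at `W.grad` and `W.data` directly to the `Tracker.update!` helper introduced later in this commit. A minimal end-to-end sketch of that workflow, assuming the Flux 0.6-era Tracker API (`param`, `back!`, `grad`, `update!`); the shapes and learning rate are arbitrary:

```julia
using Flux
using Flux.Tracker: back!, grad, update!

W, b = param(rand(2, 5)), param(rand(2))   # tracked parameters
predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)

x, y = rand(5), rand(2)
l = loss(x, y)      # a tracked number
back!(l)            # accumulate gradients into W and b
Δ = grad(W)         # gradient of the loss with respect to W
update!(W, -0.1Δ)   # take a descent step and reset W's gradient
loss(x, y)          # should now be smaller than l
```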


@@ -7,6 +7,7 @@ add the result to the overall loss.
 For example, say we have a simple regression.
 
 ```julia
+using Flux: crossentropy
 m = Dense(10, 5)
 loss(x, y) = crossentropy(softmax(m(x)), y)
 ```
@@ -43,3 +44,19 @@ loss(x, y) = crossentropy(m(x), y) + sum(vecnorm, params(m))
 
 loss(rand(28^2), rand(10))
 ```
+
+One can also easily add per-layer regularisation via the `activations` function:
+
+```julia
+julia> c = Chain(Dense(10,5,σ),Dense(5,2),softmax)
+Chain(Dense(10, 5, NNlib.σ), Dense(5, 2), NNlib.softmax)
+
+julia> activations(c, rand(10))
+3-element Array{Any,1}:
+ param([0.71068, 0.831145, 0.751219, 0.227116, 0.553074])
+ param([0.0330606, -0.456104])
+ param([0.61991, 0.38009])
+
+julia> sum(vecnorm, ans)
+2.639678767773633 (tracked)
+```
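
Tying the two hunks together, a hedged sketch of per-layer regularisation in the style the new docs text describes; the model size and penalty weight are arbitrary, and the values `activations` returns are tracked arrays, so the whole loss stays differentiable:

```julia
using Flux
using Flux: activations, crossentropy

m = Chain(Dense(28^2, 32, relu), Dense(32, 10), softmax)

# Penalise every intermediate activation as well as the usual prediction loss.
penalty(x) = sum(vecnorm, activations(m, x))
loss(x, y) = crossentropy(m(x), y) + 0.01 * penalty(x)

loss(rand(28^2), rand(10))   # a tracked scalar, ready for back!
```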


@@ -17,16 +17,17 @@ back!(l)
 
 We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
 
 ```julia
-function update()
+using Flux.Tracker: grad, update!
+
+function sgd()
   η = 0.1 # Learning Rate
   for p in (W, b)
-    p.data .-= η .* p.grad # Apply the update
-    p.grad .= 0            # Clear the gradient
+    update!(p, -η * grad(p))
   end
 end
 ```
 
-If we call `update`, the parameters `W` and `b` will change and our loss should go down.
+If we call `sgd`, the parameters `W` and `b` will change and our loss should go down.
 
 There are two pieces here: one is that we need a list of trainable parameters for the model (`[W, b]` in this case), and the other is the update step. In this case the update is simply gradient descent (`x .-= η .* Δ`), but we might choose to do something more advanced, like adding momentum.
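
The built-in optimisers wrap exactly this pattern. A hedged, self-contained sketch using `SGD` in place of the hand-written `sgd()`; the parameters and data are arbitrary:

```julia
using Flux
using Flux.Tracker: back!

W, b = param(rand(2, 5)), param(rand(2))
loss(x, y) = sum((W*x .+ b .- y).^2)

opt = SGD([W, b], 0.1)          # the built-in equivalent of sgd() above
# opt = Momentum([W, b], 0.1)   # a drop-in swap for "something more advanced"

back!(loss(rand(5), rand(2)))
opt()                           # apply the update to W and b
```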


@@ -23,7 +23,7 @@ include("optimise/Optimise.jl")
 using .Optimise
 using .Optimise: @epochs
 export SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
-  RMSProp, ADAGrad, ADADelta, AMSGrad
+  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
 
 include("utils.jl")
 include("onehot.jl")


@@ -38,6 +38,11 @@ function Base.show(io::IO, c::Chain)
   print(io, ")")
 end
 
+# Seem to need this for `accumulate`; try removing on 0.7
+Base.rcum_promote_type(op, ::Type, ::Type{Any}) = Any
+
+activations(c::Chain, x) = accumulate((x, m) -> m(x), x, c.layers)
+
 """
     Dense(in::Integer, out::Integer, σ = identity)
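
For intuition on the `accumulate` definition of `activations` above, a small sketch of the same call shape on plain numbers (Julia 0.6 syntax, with the start value passed positionally); `activations` does the same thing with `(x, m) -> m(x)`, so the result holds each layer's output in turn:

```julia
# Thread a running value through the collection, keeping every intermediate result.
accumulate(+, 0, [1, 2, 3, 4])   # [1, 3, 6, 10]
```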


@@ -1,5 +1,10 @@
 using NNlib: conv
 
+@generated sub2(::Type{Val{N}}) where N = :(Val{$(N-2)})
+
+expand(N, i::Tuple) = i
+expand(N, i::Integer) = ntuple(_ -> i, N)
+
 """
     Conv(size, in=>out)
     Conv(size, in=>out, relu)
@@ -10,7 +15,7 @@ Standard convolutional layer. `size` should be a tuple like `(2, 2)`.
 Data should be stored in WHCN order. In other words, a 100×100 RGB image would
 be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
 
-Takes the keyword arguments `pad` and `stride`.
+Takes the keyword arguments `pad`, `stride` and `dilation`.
 """
 struct Conv{N,F,A,V}
   σ::F
@@ -18,17 +23,17 @@ struct Conv{N,F,A,V}
   bias::V
   stride::NTuple{N,Int}
   pad::NTuple{N,Int}
+  dilation::NTuple{N,Int}
 end
 
-Conv(w::AbstractArray{T}, b::AbstractVector{T}, σ = identity;
-     stride = 1, pad = 0) where T =
-  Conv(σ, w, b, stride, pad)
+Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
+     stride = 1, pad = 0, dilation = 1) where {T,N} =
+  Conv(σ, w, b, expand.(sub2(Val{N}), (stride, pad, dilation))...)
 
 Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
-     stride::NTuple{N,Integer} = map(_->1,k),
-     pad::NTuple{N,Integer} = map(_->0,k)) where N =
+     stride = 1, pad = 0, dilation = 1) where N =
   Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
-       stride = stride, pad = pad)
+       stride = stride, pad = pad, dilation = dilation)
 
 Flux.treelike(Conv)
@@ -36,7 +41,7 @@ function (c::Conv)(x)
   # TODO: breaks gpu broadcast :(
   # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
   σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
-  σ.(conv(x, c.weight, stride = c.stride, pad = c.pad) .+ b)
+  σ.(conv(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
 end
 
 function Base.show(io::IO, l::Conv)
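
A hedged usage sketch for the new `dilation` keyword; the layer sizes and batch are arbitrary. A 3×3 kernel dilated by 2 covers a 5×5 window:

```julia
using Flux

c = Conv((3, 3), 1=>8, relu, dilation = 2)

x = rand(28, 28, 1, 16)   # WHCN: 28×28 images, 1 channel, batch of 16
size(c(x))                # expected (24, 24, 8, 16) with no padding: 28 - 5 + 1 = 24
```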


@@ -31,15 +31,14 @@ function Dropout(p)
   Dropout{typeof(p)}(p, true)
 end
 
+_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
+
 function (a::Dropout)(x)
   a.active || return x
   y = similar(x)
   rand!(y)
-  q = 1 - a.p
-  @inbounds for i=1:length(y)
-    y[i] = y[i] > a.p ? 1 / q : 0
-  end
-  return y .* x
+  y .= _dropout_kernel.(y, a.p, 1 - a.p)
+  return x .* y
 end
 
 _testmode!(a::Dropout, test) = (a.active = !test)
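
For intuition, `_dropout_kernel` implements inverted dropout: kept entries are scaled by `1/(1 - p)` so the expected activation is unchanged. A plain-Julia sketch of the same computation:

```julia
p = 0.5
x = rand(5)
noise = rand(5)                              # one uniform draw per activation
mask = ifelse.(noise .> p, 1/(1 - p), 0.0)   # drop with probability p, scale the rest
x .* mask
```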


@@ -15,9 +15,9 @@ function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight
 end
 
 """
-    binarycrossentropy(ŷ, y)
+    binarycrossentropy(ŷ, y; ϵ=eps(ŷ))
 
-Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`.
+Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerical stability.
 
     julia> binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0.])
     3-element Array{Float64,1}:
@@ -25,7 +25,7 @@ Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`.
      0.352317
      0.86167
 """
-binarycrossentropy(ŷ, y) = -y*log(ŷ) - (1 - y)*log(1 - ŷ)
+binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
 
 """
     logitbinarycrossentropy(logŷ, y)
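
A small numeric sketch of why the ϵ term matters: a prediction saturated at exactly 0 or 1 would otherwise hit `log(0)`:

```julia
using Flux: binarycrossentropy

ŷ, y = 1.0, 0.0
-y*log(ŷ) - (1 - y)*log(1 - ŷ)   # Inf: the un-smoothed form blows up
binarycrossentropy(ŷ, y)         # large but finite, thanks to the ϵ term
```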


@@ -1,7 +1,8 @@
 module Optimise
 
 export train!,
-  SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
+  SGD, ADAM, ADAMW, AdaMax, Momentum, Nesterov,
+  RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM
 
 struct Param{T}
   x::T


@@ -99,3 +99,12 @@ tuning.
 """
 AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
   optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
+
+"""
+    NADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
+
+[NADAM](https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ) optimiser. Parameters other
+than learning rate don't need tuning.
+"""
+NADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
+  optimiser(ps, p->nadam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
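
A hedged usage sketch for the new constructor, following the same pattern as the other optimisers; the model and data are arbitrary:

```julia
using Flux

m = Dense(10, 2)
loss(x, y) = Flux.mse(m(x), y)
opt = NADAM(params(m))   # defaults: η = 0.001, β1 = 0.9, β2 = 0.999

data = [(rand(10), rand(2)) for _ in 1:100]
Flux.train!(loss, data, opt)
```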


@@ -35,7 +35,7 @@ function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
   acc = zeros(p.x)
   function ()
     @. acc = ρ * acc + (1 - ρ) * p.Δ^2
-    @. p.Δ *= η / (√acc + ϵ)
+    @. p.Δ *= η / √(acc + ϵ)
   end
 end
@@ -43,7 +43,7 @@ function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
   acc = zeros(p.x) .+ ϵ
   function ()
     @. acc += p.Δ^2
-    @. p.Δ *= η / √acc
+    @. p.Δ *= η / √(acc + ϵ)
   end
 end
@@ -64,7 +64,7 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ
   function ()
     @. mt = β1 * mt + (1 - β1) * p.Δ
     @. vt = β2 * vt + (1 - β2) * p.Δ^2
-    @. p.Δ = mt / (1 - β1p) / (√(vt / (1 - β2p)) + ϵ) * η
+    @. p.Δ = mt / (1 - β1p) / √(vt / (1 - β2p) + ϵ) * η
     β1p *= β1
     β2p *= β2
   end
@@ -94,6 +94,19 @@ function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999,
   end
 end
 
+function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
+  mt = zeros(p.x)
+  vt = zeros(p.x)
+  β1p, β2p = β1, β2
+  function ()
+    @. mt = β1 * mt + (1 - β1) * p.Δ
+    @. vt = β2 * vt + (1 - β2) * p.Δ^2
+    @. p.Δ = (β1 * mt / (1 - β1 * β1p) + (1 - β1) * p.Δ / (1 - β1p)) / √(vt * β2 / (1 - β2p) + ϵ) * η
+    β1p *= β1
+    β2p *= β2
+  end
+end
+
 clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)
 
 function expdecay(p::Param, γ::Real)
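
Every optimiser in this file follows the same shape: capture state next to a `Param`, return a closure, and have each call rescale `p.Δ` in place (a separate `descent` step then applies it). A standalone sketch of that pattern, with a hypothetical `P` struct standing in for Flux's internal `Param`:

```julia
# Hypothetical stand-in for Param: a value and the gradient written by back!.
struct P
  x::Vector{Float64}
  Δ::Vector{Float64}
end

function rmsprop_like(p::P; η = 0.001, ρ = 0.9, ϵ = 1e-8)
  acc = zeros(p.x)             # per-parameter running average of Δ²
  function ()
    @. acc = ρ * acc + (1 - ρ) * p.Δ^2
    @. p.Δ *= η / √(acc + ϵ)   # rescale the gradient in place
  end
end

p = P(rand(3), rand(3))
step! = rmsprop_like(p)
step!()   # p.Δ now holds the scaled step a descent pass would subtract from p.x
```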


@@ -37,8 +37,6 @@ function train!(loss, data, opt; cb = () -> ())
   opt = runall(opt)
   @progress for d in data
     l = loss(d...)
-    isinf(l) && error("Loss is Inf")
-    isnan(l) && error("Loss is NaN")
     @interrupts back!(l)
     opt()
     cb() == :stop && break
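
From the caller's side the loop is unchanged; the Inf/NaN guard now lives in `back!` (see the tracker change below). A hedged sketch of `train!` with an early-stopping callback, since `cb() == :stop` breaks the loop; the model, data and threshold are arbitrary:

```julia
using Flux

m = Dense(4, 2)
loss(x, y) = Flux.mse(m(x), y)
opt = SGD(params(m), 0.1)
train_data = [(rand(4), rand(2)) for _ in 1:200]

# Stop early once the loss on a fixed probe point is small enough.
probe = (rand(4), rand(2))
cb() = Flux.Tracker.data(loss(probe...)) < 1e-2 ? :stop : nothing
Flux.train!(loss, train_data, opt, cb = cb)
```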


@@ -10,6 +10,7 @@ istracked(x) = tracker(x) ≠ nothing
 isleaf(x) = !istracked(x) || isleaf(tracker(x))
 data(x) = istracked(x) ? data(tracker(x)) : x
 grad(x) = grad(tracker(x))
+grad(::Void) = nothing
 
 struct Call{F,As<:Tuple}
   func::F
@@ -46,11 +47,27 @@ isleaf(x::Tracked) = x.f == Call(nothing)
 data(x::Tracked) = x.data
 grad(x::Tracked) = x.grad
 
+function update!(x, Δ)
+  tracker(x).data += Δ
+  tracker(x).grad .= 0
+  return x
+end
+
 include("back.jl")
 include("scalar.jl")
 include("array.jl")
 include("numeric.jl")
 
+"""
+    hook(f, x) -> x
+
+Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
+`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
+the sign of the gradient applied to `x`.
+"""
+hook(f, x) = istracked(x) ? track(hook, f, x) : x
+back(::typeof(hook), Δ, f, x) = @back(x, f(Δ))
+
 param(x::Number) = TrackedReal(float(x))
 param(xs::AbstractArray) = TrackedArray(float.(xs))
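
A hedged sketch of the new `hook`, using the docstring's own example of negating a gradient on the way back:

```julia
using Flux
using Flux.Tracker: back!, grad, hook

x = param([1.0, 2.0, 3.0])
y = hook(-, x)   # same forward value as x
back!(sum(y))
grad(x)          # [-1.0, -1.0, -1.0] rather than [1.0, 1.0, 1.0]
```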


@@ -93,6 +93,26 @@ function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
   back(xs, Δ′)
 end
 
+_repeat(A, inner, outer) = Base.repeat(A; inner=inner, outer=outer)
+Base.repeat(A::TrackedArray; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A))) = track(_repeat, A, inner, outer)
+
+function back(::typeof(_repeat), Δ, xs::TrackedArray, inner, outer)
+  Δ′ = similar(xs.data)
+  Δ′ .= 0
+  S = size(xs.data)
+
+  # Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
+  for (dest_idx, val) in enumerate(IndexCartesian(), Δ)
+    # First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
+    # wrap around based on original size S.
+    src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
+    Δ′[src_idx...] += val
+  end
+  back(xs, Δ′)
+end
+
 for f in [:vcat, :hcat]
   @eval begin
     # This section is a bit of a hack since julia doesn't have a standardised
@@ -314,18 +334,18 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
 back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs)))
 
 # TODO: can store kwargs efficiently in namedtuples
-_conv(x, w, stride, pad) = conv(x, w, stride = stride, pad = pad)
+_conv(x, w, stride, pad, dilation) = conv(x, w, stride = stride, pad = pad, dilation = dilation)
 
-conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N =
-  track(_conv, x, w, stride, pad)
-conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N =
-  track(_conv, x, w, stride, pad)
-conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0) where N =
-  track(_conv, x, w, stride, pad)
+conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
+  track(_conv, x, w, stride, pad, dilation)
+conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
+  track(_conv, x, w, stride, pad, dilation)
+conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
+  track(_conv, x, w, stride, pad, dilation)
 
-function back(::typeof(_conv), Δ, x, w, stride, pad)
-  @back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad))
-  @back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad))
+function back(::typeof(_conv), Δ, x, w, stride, pad, dilation)
+  @back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation))
+  @back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation))
 end
 
 _maxpool(x, k, pad, stride) = maxpool(x, k; pad = pad, stride = stride)
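
A hedged sketch of what the new `repeat` gradient should do: each source element receives the sum of the gradients of all of its copies:

```julia
using Flux
using Flux.Tracker: back!, grad

x = param([1.0, 2.0])
y = repeat(x, inner = 2, outer = 3)   # 12 entries; each source element appears 6 times
back!(sum(y))
grad(x)                               # [6.0, 6.0]
```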


@@ -1,4 +1,4 @@
-function gradient(f, xs::AbstractArray...)
+function gradient(f, xs...)
   xs = param.(xs)
   back!(f(xs...))
   grad.(xs)
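
Relaxing the signature from `xs::AbstractArray...` to `xs...` means plain numbers work too. A small sketch:

```julia
using Flux

Flux.Tracker.gradient((a, b) -> a * b, 2, 3)   # (3.0, 2.0)
```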


@@ -8,7 +8,11 @@ tracker(x::TrackedReal) = x.tracker
 
 track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x)))
 
-back!(x::TrackedReal) = back!(x, 1)
+function back!(x::TrackedReal)
+  isinf(x) && error("Loss is Inf")
+  isnan(x) && error("Loss is NaN")
+  return back!(x, 1)
+end
 
 function Base.show(io::IO, x::TrackedReal)
   show(io, data(x))
@@ -19,14 +23,16 @@ Base.decompose(x::TrackedReal) = Base.decompose(data(x))
 
 Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x
 
-Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T =
-  TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data)))
-
 Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))
 
+Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
+  error("Not implemented: convert tracked $S to tracked $T")
+
 Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y)
 Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y)
 
+Base.eps(x::TrackedReal) = eps(data(x))
+
 for f in :[isinf, isnan, isfinite].args
   @eval Base.$f(x::TrackedReal) = Base.$f(data(x))
 end
@@ -91,3 +97,18 @@ Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
 
 back(::typeof(getindex), Δ, t, i) =
   back(t, ntuple(j -> i == j ? Δ : 0, length(t)))
+
+# Array collection
+
+function collect(xs)
+  xs = Base.collect(xs)
+  track(Call(collect, xs), data.(xs))
+end
+
+function scan(c::Call{typeof(collect)})
+  foreach(scan, c.args[1])
+end
+
+function back(::typeof(collect), Δ, xs)
+  foreach((x, Δ) -> @back(x, Δ), xs, Δ)
+end
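
A hedged sketch of the new guard in `back!(x::TrackedReal)`: a NaN loss now fails at backpropagation time rather than being caught inside `train!`:

```julia
using Flux
using Flux.Tracker: back!

l = param(0.0) / param(0.0)   # a tracked NaN "loss"
back!(l)                      # throws: "Loss is NaN"
```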


@@ -2,6 +2,8 @@ using Base.Test
 using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
   σ, binarycrossentropy, logitbinarycrossentropy
 
+const ϵ = 1e-7
+
 @testset "losses" begin
   # First, regression-style y's
   y = [1, 1, 0, 0]
@@ -40,10 +42,11 @@ using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
   logŷ, y = randn(3), rand(3)
 
   @testset "binarycrossentropy" begin
-    @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
+    @test binarycrossentropy.(σ.(logŷ), y; ϵ=0) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
+    @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 - y).*log.(1 - σ.(logŷ) .+ eps.(σ.(logŷ)))
   end
 
   @testset "logitbinarycrossentropy" begin
-    @test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y)
+    @test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y; ϵ=0)
   end
 end


@@ -3,7 +3,7 @@ using Flux.Tracker
 
 @testset "Optimise" begin
   w = randn(10, 10)
-  @testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
+  @testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
     w′ = param(randn(10, 10))
     loss(x) = Flux.mse(w*x, w′*x)
     opt = Opt([w′])


@@ -1,5 +1,5 @@
 using Flux.Tracker, Base.Test, NNlib
-using Flux.Tracker: TrackedReal, gradcheck
+using Flux.Tracker: TrackedReal, gradcheck, grad
 using NNlib: conv
 
 gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
@@ -114,6 +114,9 @@ end
 @test gradtest(x -> repmat(x, 5,5), rand(4,5))
 @test gradtest(x -> repmat(x, 5), rand(4,5))
 
+@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
+@test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))
+
 @test gradtest(kron, rand(5), rand(3))
 @test gradtest(kron, rand(5), rand(3), rand(8))
 @test gradtest(kron, rand(5,1), rand(3,1))
@@ -220,4 +223,13 @@ b = param(rand())
 Tracker.back!(b)
 @test Tracker.grad(b) == 1
 
+@testset "collect" begin
+  x, y = param(2), param(3)
+  xy = Tracker.collect([x, y])
+  @test xy isa TrackedArray{Float64}
+  z = xy[1]*xy[2]
+  back!(z)
+  @test grad.((x,y)) == (3, 2)
+end
+
 end #testset