Resolving Merge Conflicts

thebhatman 2019-06-12 21:34:42 +05:30
commit c7c0ee2cbc
13 changed files with 244 additions and 45 deletions

CITATION.bib (new file, 29 lines)

@ -0,0 +1,29 @@
@article{Flux.jl-2018,
author = {Michael Innes and
Elliot Saba and
Keno Fischer and
Dhairya Gandhi and
Marco Concetto Rudilosso and
Neethu Mariya Joy and
Tejan Karmali and
Avik Pal and
Viral Shah},
title = {Fashionable Modelling with Flux},
journal = {CoRR},
volume = {abs/1811.01457},
year = {2018},
url = {http://arxiv.org/abs/1811.01457},
archivePrefix = {arXiv},
eprint = {1811.01457},
timestamp = {Thu, 22 Nov 2018 17:58:30 +0100},
biburl = {https://dblp.org/rec/bib/journals/corr/abs-1811-01457},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
@article{innes:2018,
author = {Mike Innes},
title = {Flux: Elegant Machine Learning with Julia},
journal = {Journal of Open Source Software},
year = {2018},
doi = {10.21105/joss.00602},
}

@ -1,5 +1,10 @@
# v0.9.0
* [Depthwise convolutional layer API changes](https://github.com/FluxML/Flux.jl/pull/756) from `in => mult` channel specification to `in => out` channel specification, deprecating the implicit `out` constructor.
* New [SkipConnection](https://github.com/FluxML/Flux.jl/pull/446), which can be used to train residual neural network architectures.
# v0.8.0
* [Dropout now has a `dims` argument for specifying the unbroadcast dimensions.](https://github.com/FluxML/Flux.jl/pull/563)
* New [ConvTranspose layer](https://github.com/FluxML/Flux.jl/pull/311).
* New [Maxout layer](https://github.com/FluxML/Flux.jl/pull/647)
* Datasets are now [hash verified on download](https://github.com/FluxML/Flux.jl/pull/585) to avoid corruption.
@ -13,6 +18,7 @@
* [Data.Iris](https://github.com/FluxML/Flux.jl/pull/652) makes Fisher's Iris dataset available with `Iris.labels` and `Iris.features`.
* New [InstanceNorm](https://github.com/FluxML/Flux.jl/pull/634), as popularized by [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
* New [GroupNorm](https://github.com/FluxML/Flux.jl/pull/696), as described in [Group Normalization](https://arxiv.org/abs/1803.08494).
* New [CrossCor](https://github.com/FluxML/Flux.jl/pull/762).
AD Changes:

@ -17,6 +17,7 @@ MaxPool
MeanPool
DepthwiseConv
ConvTranspose
CrossCor
```
## Recurrent Layers
@ -36,17 +37,7 @@ But in contrast to the layers described in the other sections are not readily gr
```@docs
Maxout
SkipConnection
```
## Activation Functions

@ -18,7 +18,7 @@ And you use less memory.
Not only should your activation functions be [type-stable](https://docs.julialang.org/en/v1/manual/performance-tips/#Write-%22type-stable%22-functions-1),
they should also preserve the type of their inputs.
A very artificial example using an activation function like
```
my_tanh(x) = Float64(tanh(x))
@ -73,4 +73,4 @@ end
```
When doing this kind of concatenation use `reduce(hcat, xs)` rather than `hcat(xs...)`.
This will avoid the splatting penalty, and will hit the optimised `reduce` method.
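
As a quick self-contained sketch of both points (the `leaky_tanh` name is just illustrative, not part of Flux):

```julia
# Type-preserving activation: `oftype(x/1, 0.01)` converts the constant to the
# floating-point type of `x`, so Float32 inputs stay Float32.
leaky_tanh(x) = oftype(x/1, 0.01) * x + tanh(x)

leaky_tanh(0.5f0)  # Float32
leaky_tanh(0.5)    # Float64

# Concatenating a vector of arrays: `reduce(hcat, xs)` hits the optimised method
# and avoids splatting a potentially long vector of arrays.
xs = [rand(Float32, 3) for _ in 1:4]
X = reduce(hcat, xs)  # 3×4 Matrix{Float32}
```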

@ -10,7 +10,7 @@ using Zygote: Params, @adjoint, gradient
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
SkipConnection, params, mapleaves, cpu, gpu, f32, f64, param, data
include("optimise/Optimise.jl") include("optimise/Optimise.jl")
using .Optimise using .Optimise

@ -190,5 +190,37 @@ function (mo::Maxout)(input::AbstractArray)
mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
end
"""
SkipConnection(layers, connection)
Creates a skip connection, consisting of a layer or `Chain` of consecutive layers
and a shortcut connection linking the block's input to its
output through a user-supplied callable.
`SkipConnection` requires the output dimension to be the same as the input.
A 'ResNet'-type skip-connection with identity shortcut would simply be
```julia
SkipConnection(layer, (a,b) -> a + b)
```
"""
struct SkipConnection
layers
connection #user can pass arbitrary connections here, such as (a,b) -> a + b
end
@treelike SkipConnection
function (skip::SkipConnection)(input)
# Apply the wrapped layers to the input, then combine their output with the original input via `connection`
skip.connection(skip.layers(input), input)
end
function Base.show(io::IO, b::SkipConnection)
print(io, "SkipConnection(")
join(io, b.layers, ", ")
print(io, ")")
end
param(x) = x
data(x) = x
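
A brief usage sketch of the `SkipConnection` defined above; the layer sizes and the addition shortcut are illustrative choices, not requirements:

```julia
using Flux

# Residual-style block: the shortcut adds the block's input to its output,
# so the wrapped layers must preserve the feature dimension.
block = Chain(Dense(64, 64, relu), Dense(64, 64))
res   = SkipConnection(block, (mx, x) -> mx .+ x)

x = rand(Float32, 64, 16)  # 64 features, batch of 16
size(res(x))               # (64, 16), same shape as the input
```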

@ -14,11 +14,11 @@ Example: Applying Conv layer to a 1-channel input using a 2x2 window size,
size = (2,2)
in = 1
out = 16
Conv((2, 2), 1=>16, relu) Conv((2, 2), 1=>16, relu)
Data should be stored in WHCN order (width, height, # channels, # batches).
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`.
@ -136,18 +136,17 @@ end
(a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))
""" """
DepthwiseConv(size, in=>out)
DepthwiseConv(size, in=>out, relu)
Depthwise convolutional layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.
Note that `out` must be an integer multiple of `in`.
Data should be stored in WHCN order. In other words, a 100×100 RGB image would
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
struct DepthwiseConv{N,M,F,A,V}
σ::F
@ -187,8 +186,8 @@ function (c::DepthwiseConv)(x)
end
function Base.show(io::IO, l::DepthwiseConv)
print(io, "DepthwiseConv(", size(l.weight)[1:ndims(l.weight)-2]) print(io, "DepthwiseConv(", size(l.weight)[1:end-2])
print(io, ", ", size(l.weight, ndims(l.weight)), "=>", size(l.weight, ndims(l.weight)-1)) print(io, ", ", size(l.weight)[end], "=>", prod(size(l.weight)[end-1:end]))
l.σ == identity || print(io, ", ", l.σ)
print(io, ")")
end
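
A short sketch of the new `in => out` specification; the channel counts and input size here are illustrative:

```julia
using Flux

# 3 input channels, 6 output channels: each input channel contributes 2 outputs.
dw = DepthwiseConv((3, 3), 3 => 6, relu)

x = rand(Float32, 28, 28, 3, 4)  # WHCN: 28×28 images, 3 channels, batch of 4
size(dw(x))                      # (26, 26, 6, 4)

# Requesting an output count that is not a multiple of the input count is an error:
# DepthwiseConv((3, 3), 3 => 10)  # throws
```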
@ -198,6 +197,76 @@ end
(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))
"""
CrossCor(size, in=>out)
CrossCor(size, in=>out, relu)
Standard cross-correlation layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.
Example: Applying CrossCor layer to a 1-channel input using a 2x2 window size,
giving us a 16-channel output. Output is activated with ReLU.
size = (2,2)
in = 1
out = 16
CrossCor((2, 2), 1=>16, relu)
Data should be stored in WHCN order (width, height, # channels, # batches).
In other words, a 100×100 RGB image would be a `100×100×3×1` array,
and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
struct CrossCor{N,M,F,A,V}
σ::F
weight::A
bias::V
stride::NTuple{N,Int}
pad::NTuple{M,Int}
dilation::NTuple{N,Int}
end
function CrossCor(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
stride = expand(Val(N-2), stride)
pad = expand(Val(2*(N-2)), pad)
dilation = expand(Val(N-2), dilation)
return CrossCor(σ, w, b, stride, pad, dilation)
end
CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
CrossCor(param(init(k..., ch...)), param(zeros(ch[2])), σ,
stride = stride, pad = pad, dilation = dilation)
@treelike CrossCor
function crosscor(x, w, ddims::DenseConvDims)
ddims = DenseConvDims(ddims, F=true)
return conv(x, w, ddims)
end
function (c::CrossCor)(x::AbstractArray)
# TODO: breaks gpu broadcast :(
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
σ.(crosscor(x, c.weight, cdims) .+ b)
end
function Base.show(io::IO, l::CrossCor)
print(io, "CrossCor(", size(l.weight)[1:ndims(l.weight)-2])
print(io, ", ", size(l.weight, ndims(l.weight)-1), "=>", size(l.weight, ndims(l.weight)))
l.σ == identity || print(io, ", ", l.σ)
print(io, ")")
end
(a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
invoke(a, Tuple{AbstractArray}, x)
(a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))
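
A minimal usage sketch of `CrossCor`; it is constructed and applied like `Conv`, and the sizes below are illustrative:

```julia
using Flux

# Cross-correlation (no kernel flip), which is what most frameworks call convolution.
c = CrossCor((2, 2), 1 => 16, relu)

x = rand(Float32, 28, 28, 1, 8)  # WHCN: 28×28 images, 1 channel, batch of 8
size(c(x))                       # (27, 27, 16, 8)
```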
""" """
MaxPool(k) MaxPool(k)

@ -3,11 +3,12 @@ istraining() = false
@adjoint istraining() = true, _ -> nothing
"""
Dropout(p, dims = :)
A Dropout layer. For each input, either sets that input to `0` (with probability
`p`) or scales it by `1/(1-p)`. The `dims` argument specifies the unbroadcasted
dimensions, i.e. `dims=1` applies dropout along columns and `dims=2` along rows. This is
used as a regularisation, i.e. it reduces overfitting during training. See also [`dropout`](@ref).
Does nothing to the input once in [`testmode!`](@ref).
"""
@ -19,13 +20,16 @@ mutable struct Dropout{F}
end
end
_dropout_shape(s, ::Colon) = size(s)
_dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
function (a::Dropout)(x)
istraining() || return x
y = similar(x, _dropout_shape(x, a.dims))
rand!(y)
y .= _dropout_kernel.(y, a.p, 1 - a.p)
return x .* y
end
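
To make the `dims` behaviour concrete, a small sketch probing the `_dropout_shape` helper defined above (an internal function, used here purely for illustration): the mask has size `1` along every dimension not listed in `dims`, so those dimensions share the same dropout decisions.

```julia
using Flux

x = rand(4, 6)

Flux._dropout_shape(x, :)  # (4, 6): an independent decision for every element
Flux._dropout_shape(x, 1)  # (4, 1): one decision per row, broadcast across all columns
Flux._dropout_shape(x, 2)  # (1, 6): one decision per column, broadcast across all rows
```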

@ -36,6 +36,10 @@ c = gpu(Conv((2,2),3=>4))
l = c(gpu(rand(10,10,3,2)))
Flux.back!(sum(l))
c = gpu(CrossCor((2,2),3=>4))
l = c(gpu(rand(10,10,3,2)))
Flux.back!(sum(l))
end
@testset "onecold gpu" begin

@ -1,16 +1,17 @@
using Flux, CuArrays, Test
trainmode(f, x...) = forward(f, x...)[1]
#
# @testset "CUDNN BatchNorm" begin
# @testset "4D Input" begin
# x = Float64.(collect(reshape(1:12, 2, 2, 3, 1)))
# m = BatchNorm(3)
# cx = gpu(x)
# cm = gpu(m)
#
# y = trainmode(m, x)
# cy = trainmode(cm, cx)
#
# # @test cy isa TrackedArray{Float32,4,CuArray{Float32,4}}
#
# @test cpu(data(cy)) ≈ data(y)
#

@ -72,4 +72,16 @@ import Flux: activations
@test length(ps) == 8 #4 alts, each with weight and bias
end
end
@testset "SkipConnection" begin
@testset "zero sum" begin
input = randn(10, 10, 10, 10)
@test SkipConnection(x -> zeros(size(x)), (a,b) -> a + b)(input) == input
end
@testset "concat size" begin
input = randn(10, 2)
@test size(SkipConnection(Dense(10,10), (a,b) -> cat(a, b, dims = 2))(input)) == (10,4)
end
end
end

@ -39,20 +39,14 @@ end
@testset "Depthwise Conv" begin @testset "Depthwise Conv" begin
r = zeros(Float32, 28, 28, 3, 5) r = zeros(Float32, 28, 28, 3, 5)
m1 = DepthwiseConv((2, 2), 3=>15)
@test size(m1(r), 3) == 15
m3 = DepthwiseConv((2, 3), 3=>9)
@test size(m3(r), 3) == 9
# Test that we cannot ask for non-integer multiplication factors
@test_throws AssertionError DepthwiseConv((2,2), 3=>10)
end
@testset "ConvTranspose" begin
@ -61,3 +55,50 @@ end
x_hat = ConvTranspose((3, 3), 1 => 1)(y)
@test size(x_hat) == size(x)
end
@testset "CrossCor" begin
x = rand(Float32, 28, 28, 1, 1)
w = rand(2,2,1,1)
y = CrossCor(w, [0.0])
@test sum(w .* x[1:2, 1:2, :, :]) == y(x)[1, 1, 1, 1]
r = zeros(Float32, 28, 28, 1, 5)
m = Chain(
CrossCor((2, 2), 1=>16, relu),
MaxPool((2,2)),
CrossCor((2, 2), 16=>8, relu),
MaxPool((2,2)),
x -> reshape(x, :, size(x, 4)),
Dense(288, 10), softmax)
@test size(m(r)) == (10, 5)
@test y(x) != Conv(w, [0.0])(x)
@test CrossCor(w[end:-1:1, end:-1:1, :, :], [0.0])(x) == Conv(w, [0.0])(x)
end
@testset "Conv with non quadratic window #700" begin
data = zeros(Float32, 7,7,1,1)
data[4,4,1,1] = 1
l = Conv((3,3), 1=>1)
expected = zeros(eltype(l.weight),5,5,1,1)
expected[2:end-1,2:end-1,1,1] = l.weight
@test expected == l(data)
l = Conv((3,1), 1=>1)
expected = zeros(eltype(l.weight),5,7,1,1)
expected[2:end-1,4,1,1] = l.weight
@test expected == l(data)
l = Conv((1,3), 1=>1)
expected = zeros(eltype(l.weight),7,5,1,1)
expected[4,2:end-1,1,1] = l.weight
@test expected == l(data)
@test begin
# we test that the next expression does not throw
randn(Float32, 10,10,1,1) |> Conv((6,1), 1=>1, Flux.σ)
true
end
end

@ -25,6 +25,16 @@ trainmode(f, x...) = forward(f, x...)[1]
@test count(a->a == 0, y) > 50
y = m(x)
@test count(a->a == 0, y) == 0
x = rand(100, 50)
m = Dropout(0.5, dims = 2)
y = m(x)
c = map(i->count(a->a==0, @view y[i, :]), 1:100)
@test minimum(c) == maximum(c)
m = Dropout(0.5, dims = 1)
y = m(x)
c = map(i->count(a->a==0, @view y[:, i]), 1:50)
@test minimum(c) == maximum(c)
end
@testset "BatchNorm" begin