446: Added the SkipConnection layer and constructor r=MikeInnes a=bhvieira

I added a DenseBlock constructor, which allows one to train DenseNets (you can train ResNets and MixNets with this as well; you only need to change the connection, which is concatenation for DenseNets).
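To make the two styles concrete, here's a minimal sketch (the layer sizes are made up for illustration) of the same block with an additive ResNet-style connection versus a concatenating DenseNet-style one:

```julia
using Flux

# ResNet-style: combine the block's output with its input by addition.
resblock = SkipConnection(Dense(10, 10), (a, b) -> a + b)

# DenseNet-style: concatenate output and input along the feature dimension.
denseblock = SkipConnection(Dense(10, 10), (a, b) -> cat(a, b, dims = 1))

x = randn(10, 4)     # 10 features, minibatch of 4
size(resblock(x))    # (10, 4)
size(denseblock(x))  # (20, 4)
```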

Disclaimer: I created the block for a 3D U-Net, so the assumption here is that whatever layer is inside the block, its output has the same spatial dimensions (i.e. all array dimensions excluding the channel and minibatch dimensions) as the input; otherwise the connection wouldn't match. I'm not sure this matches the topology of every DenseNet out there, but I suppose this is a good starting point.
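As a sketch of that assumption (the sizes here are hypothetical): a padded convolution preserves the spatial dimensions, so concatenating along the channel dimension lines up:

```julia
using Flux

# A 3×3×3 kernel with pad = (1, 1, 1) keeps the spatial dims unchanged, so
# the shortcut can concatenate along the channel dimension (dim 4 for 3D data).
block = SkipConnection(Conv((3, 3, 3), 4 => 8, pad = (1, 1, 1)),
                       (a, b) -> cat(a, b, dims = 4))

x = randn(Float32, 16, 16, 16, 4, 2)  # W × H × D × channels × batch
size(block(x))                        # (16, 16, 16, 12, 2)
```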

No tests yet; I will add them as the PR evolves.

I'm open to suggestions! :)


Co-authored-by: Bruno Hebling Vieira <bruno.hebling.vieira@usp.br>
Co-authored-by: Mike J Innes <mike.j.innes@gmail.com>
bors[bot] 2019-06-05 13:28:41 +00:00
commit 1902c0e7c5
5 changed files with 49 additions and 1 deletion


@@ -1,5 +1,6 @@
# v0.9.0
* [Depthwise convolutional layer API changes](https://github.com/FluxML/Flux.jl/pull/756) from `in => mult` channel specification to `in => out` channel specification, and deprecates implicit `out` constructor.
* New [SkipConnection](https://github.com/FluxML/Flux.jl/pull/446), which can be used to train residual neural network architectures.
# v0.8.0


@@ -37,6 +37,7 @@ But in contrast to the layers described in the other sections are not readily gr
```@docs
Maxout
SkipConnection
```
## Activation Functions


@@ -8,6 +8,7 @@ using MacroTools: @forward
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
       DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
       SkipConnection,
       params, mapleaves, cpu, gpu, f32, f64
@reexport using NNlib


@@ -189,3 +189,36 @@ end
function (mo::Maxout)(input::AbstractArray)
  mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
end
"""
    SkipConnection(layers, connection)

Creates a skip connection, which consists of a layer or `Chain` of consecutive layers
and a shortcut connection linking the block's input to its output
through a user-supplied callable.
`SkipConnection` requires the output of the inner layers to have the same dimensions as the input.
A 'ResNet'-type skip connection with an identity shortcut would simply be
```julia
SkipConnection(layer, (a,b) -> a + b)
```
"""
struct SkipConnection
  layers
  connection  # the user can pass an arbitrary connection here, such as (a, b) -> a + b
end
@treelike SkipConnection
function (skip::SkipConnection)(input)
  # Apply the layers to the input, then combine the result with the original input
  skip.connection(skip.layers(input), input)
end
function Base.show(io::IO, b::SkipConnection)
  print(io, "SkipConnection(")
  join(io, b.layers, ", ")
  print(io, ")")
end


@@ -72,4 +72,16 @@ import Flux: activations
      @test length(ps) == 8  # 4 alts, each with weight and bias
    end
  end
  @testset "SkipConnection" begin
    @testset "zero sum" begin
      input = randn(10, 10, 10, 10)
      @test SkipConnection(x -> zeros(size(x)), (a, b) -> a + b)(input) == input
    end
    @testset "concat size" begin
      input = randn(10, 2)
      @test size(SkipConnection(Dense(10, 10), (a, b) -> cat(a, b, dims = 2))(input)) == (10, 4)
    end
  end
end