From e7d76b8423818c5a165e388dd3b090cc5bf42cbb Mon Sep 17 00:00:00 2001
From: Bruno Hebling Vieira
Date: Sat, 20 Oct 2018 16:36:16 -0300
Subject: [PATCH 1/4] Added the SkipConnection layer and constructor

Added missing export

Corrected channel placement

Dimension 4 cannot be assumed to always be the Channel dimension

Deprecation of `treelike`

Code now makes use of the `@treelike` macro instead of the deprecated
`treelike` function (it worked on my end because I'm on Julia 0.7, while
Julia 1.0 deprecated it)

Update basic.jl

Renaming to SkipConnection

* Update Flux.jl

* Update basic.jl

Updated `SkipConnection` with a `connection` field

I'm pretty sure I broke something now, but this PR should follow along
these lines

`cat` needs special treatment (the user can declare their own
`concatenate` connection, but I foresee it will be used often, so we can
simply define special treatment for it)

Forgot to remove some rebasing text

Forgot to remove some more rebasing text

Removed local copy and default cat method from the function calls

Adjusted some more types for inference; could improve on this as well

Re-placed some left-over spaces
---
 src/Flux.jl         |  2 +-
 src/layers/basic.jl | 30 ++++++++++++++++++++++++++++++
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/src/Flux.jl b/src/Flux.jl
index eccdd6a7..bb378d23 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -6,7 +6,7 @@ using Base: tail
 using MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward
 
-export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
+export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, SkipConnection, MaxPool, MeanPool,
        DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
        params, mapleaves, cpu, gpu, f32, f64
 
diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index e640bb24..25107120 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -189,3 +189,33 @@ end
 function (mo::Maxout)(input::AbstractArray)
   mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
 end
+
+
+"""
+    SkipConnection(layers...)
+
+
+Creates a skip connection, which consists of a layer or a `Chain` of consecutive layers
+and a shortcut connection linking the input of the block to the
+output through a user-supplied function.
+
+`SkipConnection` requires the output dimension to be the same as the input.
+""" + +struct SkipConnection + layers + connection::Function #user can pass arbitrary connections here, such as `.+` +end + +@treelike SkipConnection + +function (skip::SkipConnection)(input::AbstractArray) + #We apply the layers to the input and return the result of the application of the layers and the original input + skip.connection(skip.layers(input), input) +end + +function Base.show(io::IO, b::SkipConnection) + print(io, "SkipConnection(") + join(io, b.layers, ", ") + print(io, ")") +end From c5fc2fb9a3dd34e4cc99f7fa1e0ee530f0fd0a29 Mon Sep 17 00:00:00 2001 From: Bruno Hebling Vieira Date: Mon, 13 May 2019 10:21:25 -0300 Subject: [PATCH 2/4] Added tests --- test/layers/basic.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 4deb545f..cbe250fc 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -72,4 +72,16 @@ import Flux: activations @test length(ps) == 8 #4 alts, each with weight and bias end end + + @testset "SkipConnection" begin + @testset "zero sum" begin + input = randn(10, 10, 10, 10) + @test SkipConnection(x -> zeros(size(x)), (a,b) -> a + b)(input) == input + end + + @testset "concat size" begin + input = randn(10, 2) + @test size(SkipConnection(Dense(10,10), (a,b) -> cat(a, b, dims = 2))(input)) == (10,4) + end + end end From 796a2957c94d161755df23a5b49ddd7055bee811 Mon Sep 17 00:00:00 2001 From: Bruno Hebling Vieira Date: Mon, 13 May 2019 13:47:46 -0300 Subject: [PATCH 3/4] Added news and removed type annotation from SkipConnection structure --- NEWS.md | 1 + src/layers/basic.jl | 13 ++++++++----- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/NEWS.md b/NEWS.md index 8ac937ce..27412e26 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,6 @@ # v0.9.0 * [Depthwise convolutional layer API changes](https://github.com/FluxML/Flux.jl/pull/756) from `in => mult` channel specification to `in => out` channel specification, and deprecates implicit `out` constructor. +* New [SkipConnection](https://github.com/FluxML/Flux.jl/pull/446), which can be used to train residual neural network architectures. # v0.8.0 diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 25107120..12d4e2e3 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -190,26 +190,29 @@ function (mo::Maxout)(input::AbstractArray) mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over) end - """ SkipConnection(layers...) - Creates a Skip Connection, which constitutes of a layer or Chain of consecutive layers and a shortcut connection linking the input to the block to the -output through a user-supplied function. +output through a user-supplied callable. `SkipConnection` requires the output dimension to be the same as the input. 
+
+A ResNet-style skip connection with an identity shortcut would simply be
+```julia
+  SkipConnection(layer, (a, b) -> a + b)
+```
 """
 
 struct SkipConnection
   layers
-  connection::Function  # the user can pass arbitrary connections here, such as `.+`
+  connection  # the user can pass arbitrary connections here, such as (a, b) -> a + b
 end
 
 @treelike SkipConnection
 
-function (skip::SkipConnection)(input::AbstractArray)
+function (skip::SkipConnection)(input)
   # Apply the layers to the input, then combine the result with the original input
   skip.connection(skip.layers(input), input)
 end

From 6b3cd825b97b0b95bdd634c2c6a0d6c7e78f09bc Mon Sep 17 00:00:00 2001
From: Bruno Hebling Vieira
Date: Mon, 13 May 2019 16:43:14 -0300
Subject: [PATCH 4/4] Added SkipConnection to docs, tentatively in Other
 General Purpose Layers

---
 docs/src/models/layers.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md
index ec45c31e..23b3b2a2 100644
--- a/docs/src/models/layers.md
+++ b/docs/src/models/layers.md
@@ -36,6 +36,7 @@ But in contrast to the layers described in the other sections are not readily grouped around a particular purpose.
 
 ```@docs
 Maxout
+SkipConnection
 ```
 
 ## Activation Functions
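---

For reference, a minimal usage sketch of `SkipConnection` as defined by this series (assuming all four patches are applied; the variable names and layer sizes below are illustrative, not part of the PR):

```julia
using Flux

# ResNet-style residual block: the output of the inner Chain is added
# elementwise to the block's input, so input and output sizes must match.
residual = SkipConnection(Chain(Dense(10, 10, relu), Dense(10, 10)), (a, b) -> a + b)

# DenseNet-style block: the layer output is concatenated with the input
# along the feature dimension, so only the batch dimension has to agree.
concatenated = SkipConnection(Dense(10, 10), (a, b) -> cat(a, b, dims = 1))

x = randn(10, 4)       # 10 features, batch of 4
size(residual(x))      # (10, 4)
size(concatenated(x))  # (20, 4)
```

Because patch 3 removes the `::Function` annotation on `connection`, any two-argument callable works there, not just anonymous functions.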