diff --git a/NEWS.md b/NEWS.md
index f01f1259..a3586e83 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,5 +1,6 @@
 # v0.9.0
 * [Depthwise convolutional layer API changes](https://github.com/FluxML/Flux.jl/pull/756) from `in => mult` channel specification to `in => out` channel specification, and deprecates implicit `out` constructor.
+* New [SkipConnection](https://github.com/FluxML/Flux.jl/pull/446), which can be used to train residual neural network architectures.
 
 # v0.8.0
 
diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md
index 3acb910d..f2bd8046 100644
--- a/docs/src/models/layers.md
+++ b/docs/src/models/layers.md
@@ -37,6 +37,7 @@ But in contrast to the layers described in the other sections are not readily gr
 
 ```@docs
 Maxout
+SkipConnection
 ```
 
 ## Activation Functions
diff --git a/src/Flux.jl b/src/Flux.jl
index a041a69a..94f586d9 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -7,7 +7,8 @@ using MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward
 
 export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
-       DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
+       DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
+       SkipConnection,
        params, mapleaves, cpu, gpu, f32, f64
 
 @reexport using NNlib
diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index e640bb24..12d4e2e3 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -189,3 +189,36 @@ end
 function (mo::Maxout)(input::AbstractArray)
   mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
 end
+
+"""
+    SkipConnection(layers, connection)
+
+Creates a skip connection, which consists of a layer or `Chain` of consecutive layers
+and a shortcut connection linking the block's input to its output
+through a user-supplied two-argument callable.
+
+The output of `layers` must be combinable with the original input by `connection`;
+for an elementwise shortcut such as `+`, both must have the same size.
+
+A 'ResNet'-type skip connection with an identity shortcut is simply
+```julia
+  SkipConnection(layer, (a,b) -> a + b)
+```
+"""
+struct SkipConnection
+  layers
+  connection  # user-supplied callable, such as (a,b) -> a + b
+end
+
+@treelike SkipConnection
+
+function (skip::SkipConnection)(input)
+  # Combine the output of the wrapped layers with the original input via the connection
+  skip.connection(skip.layers(input), input)
+end
+
+function Base.show(io::IO, b::SkipConnection)
+  print(io, "SkipConnection(")
+  join(io, b.layers, ", ")
+  print(io, ")")
+end
diff --git a/test/layers/basic.jl b/test/layers/basic.jl
index 4deb545f..cbe250fc 100644
--- a/test/layers/basic.jl
+++ b/test/layers/basic.jl
@@ -72,4 +72,16 @@ import Flux: activations
       @test length(ps) == 8 #4 alts, each with weight and bias
     end
   end
+
+  @testset "SkipConnection" begin
+    @testset "zero sum" begin
+      input = randn(10, 10, 10, 10)
+      @test SkipConnection(x -> zeros(size(x)), (a,b) -> a + b)(input) == input
+    end
+
+    @testset "concat size" begin
+      input = randn(10, 2)
+      @test size(SkipConnection(Dense(10,10), (a,b) -> cat(a, b, dims = 2))(input)) == (10,4)
+    end
+  end
 end
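
For reference, a minimal usage sketch of the layer added above, based on the constructor `SkipConnection(layers, connection)` and the forward pass `connection(layers(input), input)` defined in this diff. The variable names, layer sizes, and data here are illustrative only and are not part of the change.

```julia
using Flux

# ResNet-style block: add the block's input back onto its output.
# For `+` to apply, the wrapped layers must preserve the feature dimension.
residual = SkipConnection(Chain(Dense(10, 10, relu), Dense(10, 10)), (a, b) -> a + b)

x = randn(10, 5)                      # 10 features, batch of 5
@assert size(residual(x)) == (10, 5)  # output keeps the input's size

# DenseNet-style block: concatenate output and input along the feature
# dimension instead, doubling the number of features.
concatenated = SkipConnection(Dense(10, 10, relu), (a, b) -> cat(a, b, dims = 1))
@assert size(concatenated(x)) == (20, 5)

# Because the struct is @treelike, the wrapped layers' parameters are
# collected by params and trained as usual.
@assert length(Flux.params(residual)) == 4  # two Dense layers, each with weight and bias
```

Since `connection` is an arbitrary callable, the same wrapper should also compose inside a `Chain`, e.g. `Chain(Dense(10, 10), SkipConnection(Dense(10, 10), +), softmax)`.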