diff --git a/src/Flux.jl b/src/Flux.jl index eccdd6a7..bb378d23 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -6,7 +6,7 @@ using Base: tail using MacroTools, Juno, Requires, Reexport, Statistics, Random using MacroTools: @forward -export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool, +export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, SkipConnection, MaxPool, MeanPool, DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm, params, mapleaves, cpu, gpu, f32, f64 diff --git a/src/layers/basic.jl b/src/layers/basic.jl index e640bb24..25107120 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -189,3 +189,33 @@ end function (mo::Maxout)(input::AbstractArray) mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over) end + + +""" + SkipConnection(layers...) + + +Creates a Skip Connection, which constitutes of a layer or Chain of consecutive layers +and a shortcut connection linking the input to the block to the +output through a user-supplied function. + +`SkipConnection` requires the output dimension to be the same as the input. +""" + +struct SkipConnection + layers + connection::Function #user can pass arbitrary connections here, such as `.+` +end + +@treelike SkipConnection + +function (skip::SkipConnection)(input::AbstractArray) + #We apply the layers to the input and return the result of the application of the layers and the original input + skip.connection(skip.layers(input), input) +end + +function Base.show(io::IO, b::SkipConnection) + print(io, "SkipConnection(") + join(io, b.layers, ", ") + print(io, ")") +end