From e7d76b8423818c5a165e388dd3b090cc5bf42cbb Mon Sep 17 00:00:00 2001 From: Bruno Hebling Vieira Date: Sat, 20 Oct 2018 16:36:16 -0300 Subject: [PATCH] Added the SkipConnection layer and constructor Added missing export Corrected channel placement Dimension 4 cannot be assumed to always be the Channel dimension Deprecation of `treelike` Code now makes use of `@treelike` macro instead of the deprecated `treelike` function (it worked on my end because I'm on Julia 0.7, while Julia 1.0 deprecated stuff) Update basic.jl Renaming to SkipConnection * Update Flux.jl * Update basic.jl Updated `SkipConnection` with a `connection` field I'm pretty sure I broke something now, but this PR should follow along these lines `cat` needs special treatment (the user can declare his own `concatenate` connection, but I foresee it's going to be used often so we can simply define special treatment) Forgot to remove some rebasing text Forgot to remove some more rebasing text Removed local copy and default cat method from the function calls Adjusted some more types for inference, could improve on this as well Re-placed some left-over spaces --- src/Flux.jl | 2 +- src/layers/basic.jl | 30 ++++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/src/Flux.jl b/src/Flux.jl index eccdd6a7..bb378d23 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -6,7 +6,7 @@ using Base: tail using MacroTools, Juno, Requires, Reexport, Statistics, Random using MacroTools: @forward -export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool, +export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, SkipConnection, MaxPool, MeanPool, DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm, params, mapleaves, cpu, gpu, f32, f64 diff --git a/src/layers/basic.jl b/src/layers/basic.jl index e640bb24..25107120 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -189,3 +189,33 @@ end function (mo::Maxout)(input::AbstractArray) mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over) end + + +""" + SkipConnection(layers...) + + +Creates a Skip Connection, which constitutes of a layer or Chain of consecutive layers +and a shortcut connection linking the input to the block to the +output through a user-supplied function. + +`SkipConnection` requires the output dimension to be the same as the input. +""" + +struct SkipConnection + layers + connection::Function #user can pass arbitrary connections here, such as `.+` +end + +@treelike SkipConnection + +function (skip::SkipConnection)(input::AbstractArray) + #We apply the layers to the input and return the result of the application of the layers and the original input + skip.connection(skip.layers(input), input) +end + +function Base.show(io::IO, b::SkipConnection) + print(io, "SkipConnection(") + join(io, b.layers, ", ") + print(io, ")") +end