Added the SkipConnection layer and constructor

Added missing export

Corrected channel placement

Dimension 4 cannot be assumed to always be the Channel dimension

Deprecation of `treelike`

Code now makes use of `@treelike` macro instead of the deprecated `treelike` function (it worked on my end because I'm on Julia 0.7, while Julia 1.0 deprecated stuff)

Update basic.jl

Renaming to SkipConnection

* Update Flux.jl

* Update basic.jl

Updated `SkipConnection` with a `connection` field

I'm pretty sure I broke something now, but this PR should follow along these lines `cat` needs special treatment (the user can declare his own `concatenate` connection, but I foresee it's going to be used often so we can simply define special treatment)

Forgot to remove some rebasing text

Forgot to remove some more rebasing text

Removed local copy and default cat method from the function calls

Adjusted some more types for inference, could improve on this as well

Re-placed some left-over spaces
This commit is contained in:
Bruno Hebling Vieira 2018-10-20 16:36:16 -03:00
parent 308b199bd0
commit e7d76b8423
2 changed files with 31 additions and 1 deletions

View File

@ -6,7 +6,7 @@ using Base: tail
using MacroTools, Juno, Requires, Reexport, Statistics, Random using MacroTools, Juno, Requires, Reexport, Statistics, Random
using MacroTools: @forward using MacroTools: @forward
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool, export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, SkipConnection, MaxPool, MeanPool,
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm, DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
params, mapleaves, cpu, gpu, f32, f64 params, mapleaves, cpu, gpu, f32, f64

View File

@ -189,3 +189,33 @@ end
function (mo::Maxout)(input::AbstractArray) function (mo::Maxout)(input::AbstractArray)
mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over) mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
end end
"""
SkipConnection(layers...)
Creates a Skip Connection, which constitutes of a layer or Chain of consecutive layers
and a shortcut connection linking the input to the block to the
output through a user-supplied function.
`SkipConnection` requires the output dimension to be the same as the input.
"""
struct SkipConnection
layers
connection::Function #user can pass arbitrary connections here, such as `.+`
end
@treelike SkipConnection
function (skip::SkipConnection)(input::AbstractArray)
#We apply the layers to the input and return the result of the application of the layers and the original input
skip.connection(skip.layers(input), input)
end
function Base.show(io::IO, b::SkipConnection)
print(io, "SkipConnection(")
join(io, b.layers, ", ")
print(io, ")")
end