Added the SkipConnection layer and constructor
Added missing export Corrected channel placement Dimension 4 cannot be assumed to always be the Channel dimension Deprecation of `treelike` Code now makes use of `@treelike` macro instead of the deprecated `treelike` function (it worked on my end because I'm on Julia 0.7, while Julia 1.0 deprecated stuff) Update basic.jl Renaming to SkipConnection * Update Flux.jl * Update basic.jl Updated `SkipConnection` with a `connection` field I'm pretty sure I broke something now, but this PR should follow along these lines `cat` needs special treatment (the user can declare his own `concatenate` connection, but I foresee it's going to be used often so we can simply define special treatment) Forgot to remove some rebasing text Forgot to remove some more rebasing text Removed local copy and default cat method from the function calls Adjusted some more types for inference, could improve on this as well Re-placed some left-over spaces
This commit is contained in:
parent
308b199bd0
commit
e7d76b8423
@ -6,7 +6,7 @@ using Base: tail
|
||||
using MacroTools, Juno, Requires, Reexport, Statistics, Random
|
||||
using MacroTools: @forward
|
||||
|
||||
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
|
||||
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, SkipConnection, MaxPool, MeanPool,
|
||||
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
|
||||
params, mapleaves, cpu, gpu, f32, f64
|
||||
|
||||
|
@ -189,3 +189,33 @@ end
|
||||
function (mo::Maxout)(input::AbstractArray)
|
||||
mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
|
||||
end
|
||||
|
||||
|
||||
"""
|
||||
SkipConnection(layers...)
|
||||
|
||||
|
||||
Creates a Skip Connection, which constitutes of a layer or Chain of consecutive layers
|
||||
and a shortcut connection linking the input to the block to the
|
||||
output through a user-supplied function.
|
||||
|
||||
`SkipConnection` requires the output dimension to be the same as the input.
|
||||
"""
|
||||
|
||||
struct SkipConnection
|
||||
layers
|
||||
connection::Function #user can pass arbitrary connections here, such as `.+`
|
||||
end
|
||||
|
||||
@treelike SkipConnection
|
||||
|
||||
function (skip::SkipConnection)(input::AbstractArray)
|
||||
#We apply the layers to the input and return the result of the application of the layers and the original input
|
||||
skip.connection(skip.layers(input), input)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, b::SkipConnection)
|
||||
print(io, "SkipConnection(")
|
||||
join(io, b.layers, ", ")
|
||||
print(io, ")")
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user