This commit is contained in:
Michael Abbott 2019-09-25 15:18:40 +02:00
parent 2de84ce79f
commit 4245d9acad

View File

@ -192,17 +192,23 @@ end
"""
SkipConnection(layers, connection)
Creates a Skip Connection, which constitutes of a layer or `Chain` of consecutive layers
and a shortcut connection linking the input to the block to the
output through a user-supplied callable.
Creates a Skip Connection, of a layer or `Chain` of consecutive layers
plus a shortcut connection. The connection function will combine the result of the layers
with the original input, to give the final output.
`SkipConnection` requires the output dimension to be the same as the input.
The simplest 'ResNet'-type connection is just `SkipConnection(layer, +)`,
and requires the output of the layers to be the same shape as the input.
Here is a more complicated example:
```
m = Conv((3,3), 4=>7, pad=(1,1))
x = ones(5,5,4,10);
size(m(x)) == (5, 5, 7, 10)
A 'ResNet'-type skip-connection with identity shortcut would simply be
```julia
SkipConnection(layer, +)
sm = SkipConnection(m, (mx, x) -> cat(mx, x, dims=3))
size(sm(x)) == (5, 5, 11, 10)
```
"""
function SkipConnection end
struct SkipConnection
layers
connection #user can pass arbitrary connections here, such as (a,b) -> a + b