870: Fix printing of SkipConnection r=MikeInnes a=mcabbott

Before:
```
julia> SkipConnection(Dense(2,2),+)
SkipConnection(Error showing value of type SkipConnection:
ERROR: MethodError: no method matching iterate(::Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}})

julia> SkipConnection(Chain(Dense(2,3), Dense(3,2), LayerNorm(2)),+)
SkipConnection(Dense(2, 3), Dense(3, 2), LayerNorm(2))

julia> SkipConnection(Dense(2, 3), Dense(3, 2), LayerNorm(2))
ERROR: MethodError: no method matching SkipConnection(::Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}, ::Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}, ::LayerNorm{TrackedArray{…,Array{Float32,1}}})
```
After:
```
julia> SkipConnection(Dense(2,2),+)
SkipConnection(Dense(2, 2), +)

julia> SkipConnection(Chain(Dense(2,3), Dense(3,2), LayerNorm(2)),+)
SkipConnection(Chain(Dense(2, 3), Dense(3, 2), LayerNorm(2)), +)

julia> SkipConnection(Dense(2,2), (a,b) -> a .+ b./2)
SkipConnection(Dense(2, 2), #9)
```

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
This commit is contained in:
bors[bot] 2019-09-25 14:09:28 +00:00 committed by GitHub
commit 12bc06136d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 14 additions and 12 deletions

View File

@ -190,17 +190,22 @@ function (mo::Maxout)(input::AbstractArray)
end
"""
SkipConnection(layers...)
SkipConnection(layers, connection)
Creates a Skip Connection, which constitutes of a layer or Chain of consecutive layers
and a shortcut connection linking the input to the block to the
output through a user-supplied callable.
Creates a Skip Connection, of a layer or `Chain` of consecutive layers
plus a shortcut connection. The connection function will combine the result of the layers
with the original input, to give the final output.
`SkipConnection` requires the output dimension to be the same as the input.
The simplest 'ResNet'-type connection is just `SkipConnection(layer, +)`,
and requires the output of the layers to be the same shape as the input.
Here is a more complicated example:
```
m = Conv((3,3), 4=>7, pad=(1,1))
x = ones(5,5,4,10);
size(m(x)) == (5, 5, 7, 10)
A 'ResNet'-type skip-connection with identity shortcut would simply be
```julia
SkipConnection(layer, (a,b) -> a + b)
sm = SkipConnection(m, (mx, x) -> cat(mx, x, dims=3))
size(sm(x)) == (5, 5, 11, 10)
```
"""
struct SkipConnection
@ -211,12 +216,9 @@ end
@functor SkipConnection
function (skip::SkipConnection)(input)
#We apply the layers to the input and return the result of the application of the layers and the original input
skip.connection(skip.layers(input), input)
end
function Base.show(io::IO, b::SkipConnection)
print(io, "SkipConnection(")
join(io, b.layers, ", ")
print(io, ")")
print(io, "SkipConnection(", b.layers, ", ", b.connection, ")")
end