Merge #870
870: Fix printing of SkipConnection r=MikeInnes a=mcabbott Before: ``` julia> SkipConnection(Dense(2,2),+) SkipConnection(Error showing value of type SkipConnection: ERROR: MethodError: no method matching iterate(::Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}) julia> SkipConnection(Chain(Dense(2,3), Dense(3,2), LayerNorm(2)),+) SkipConnection(Dense(2, 3), Dense(3, 2), LayerNorm(2)) julia> SkipConnection(Dense(2, 3), Dense(3, 2), LayerNorm(2)) ERROR: MethodError: no method matching SkipConnection(::Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}, ::Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}, ::LayerNorm{TrackedArray{…,Array{Float32,1}}}) ``` After: ``` julia> SkipConnection(Dense(2,2),+) SkipConnection(Dense(2, 2), +) julia> SkipConnection(Chain(Dense(2,3), Dense(3,2), LayerNorm(2)),+) SkipConnection(Chain(Dense(2, 3), Dense(3, 2), LayerNorm(2)), +) julia> SkipConnection(Dense(2,2), (a,b) -> a .+ b./2) SkipConnection(Dense(2, 2), #9) ``` Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
This commit is contained in:
commit
12bc06136d
|
@ -190,17 +190,22 @@ function (mo::Maxout)(input::AbstractArray)
|
|||
end
|
||||
|
||||
"""
|
||||
SkipConnection(layers...)
|
||||
SkipConnection(layers, connection)
|
||||
|
||||
Creates a Skip Connection, which constitutes of a layer or Chain of consecutive layers
|
||||
and a shortcut connection linking the input to the block to the
|
||||
output through a user-supplied callable.
|
||||
Creates a Skip Connection, of a layer or `Chain` of consecutive layers
|
||||
plus a shortcut connection. The connection function will combine the result of the layers
|
||||
with the original input, to give the final output.
|
||||
|
||||
`SkipConnection` requires the output dimension to be the same as the input.
|
||||
The simplest 'ResNet'-type connection is just `SkipConnection(layer, +)`,
|
||||
and requires the output of the layers to be the same shape as the input.
|
||||
Here is a more complicated example:
|
||||
```
|
||||
m = Conv((3,3), 4=>7, pad=(1,1))
|
||||
x = ones(5,5,4,10);
|
||||
size(m(x)) == (5, 5, 7, 10)
|
||||
|
||||
A 'ResNet'-type skip-connection with identity shortcut would simply be
|
||||
```julia
|
||||
SkipConnection(layer, (a,b) -> a + b)
|
||||
sm = SkipConnection(m, (mx, x) -> cat(mx, x, dims=3))
|
||||
size(sm(x)) == (5, 5, 11, 10)
|
||||
```
|
||||
"""
|
||||
struct SkipConnection
|
||||
|
@ -211,12 +216,9 @@ end
|
|||
@functor SkipConnection
|
||||
|
||||
function (skip::SkipConnection)(input)
|
||||
#We apply the layers to the input and return the result of the application of the layers and the original input
|
||||
skip.connection(skip.layers(input), input)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, b::SkipConnection)
|
||||
print(io, "SkipConnection(")
|
||||
join(io, b.layers, ", ")
|
||||
print(io, ")")
|
||||
print(io, "SkipConnection(", b.layers, ", ", b.connection, ")")
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue