Merge #446
446: Added the SkipConnection layer and constructor r=MikeInnes a=bhvieira

I added a DenseBlock constructor, which allows one to train DenseNets (you can train ResNets and MixNets with this as well; you only need to change the connection, which is concatenation for DenseNets).

Disclaimer: I created the block for a 3D U-Net, so the assumption here is that whatever layer is inside the block, its output has the same spatial dimensions (i.e. all array dimensions excluding the channel and minibatch dimensions) as the input; otherwise the connection wouldn't match. I'm not sure this matches the topology of every DenseNet out there, but I suppose this is a good starting point.

No tests yet; I will add them as the PR evolves. I'm open to suggestions! :)

Co-authored-by: Bruno Hebling Vieira <bruno.hebling.vieira@usp.br>
Co-authored-by: Mike J Innes <mike.j.innes@gmail.com>
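A minimal sketch of the idea described above: the `connection` callable decides the architecture family. Layer shapes and sizes here are illustrative assumptions, not taken from the PR; `cat` along dimension 3 targets the channel dimension of Flux's WHCN array layout.

```julia
using Flux

# ResNet-style block: the shortcut is added elementwise to the block's output.
resnet_block = SkipConnection(
    Chain(Conv((3, 3), 16 => 16, relu, pad = 1),
          Conv((3, 3), 16 => 16, pad = 1)),
    (a, b) -> a + b)

# DenseNet-style block: the shortcut is concatenated along the channel
# dimension, so the channel count grows while spatial dimensions must match.
dense_block = SkipConnection(
    Conv((3, 3), 16 => 8, relu, pad = 1),
    (a, b) -> cat(a, b, dims = 3))

x = randn(Float32, 32, 32, 16, 1)  # one 32×32 input with 16 channels
size(resnet_block(x))  # (32, 32, 16, 1)
size(dense_block(x))   # (32, 32, 24, 1): 8 new channels plus the 16 original
```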
This commit is contained in:
commit 1902c0e7c5

NEWS.md
@@ -1,5 +1,6 @@
# v0.9.0

* [Depthwise convolutional layer API changes](https://github.com/FluxML/Flux.jl/pull/756) from `in => mult` channel specification to `in => out` channel specification, and deprecates implicit `out` constructor.
* New [SkipConnection](https://github.com/FluxML/Flux.jl/pull/446), which can be used to train residual neural network architectures.

# v0.8.0
@@ -37,6 +37,7 @@ But in contrast to the layers described in the other sections are not readily gr

```@docs
Maxout
SkipConnection
```

## Activation Functions
@@ -8,6 +8,7 @@ using MacroTools: @forward

export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
       DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
       SkipConnection,
       params, mapleaves, cpu, gpu, f32, f64

@reexport using NNlib
@@ -189,3 +189,36 @@ end
function (mo::Maxout)(input::AbstractArray)
  mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
end

"""
    SkipConnection(layers, connection)

Creates a skip connection, which consists of a layer or `Chain` of consecutive layers
and a shortcut connection linking the input of the block to its output
through a user-supplied callable.

`SkipConnection` requires the output dimension to be the same as the input.

A 'ResNet'-type skip connection with identity shortcut would simply be
```julia
SkipConnection(layer, (a, b) -> a + b)
```
"""
struct SkipConnection
  layers
  connection  # user can pass arbitrary connections here, such as (a, b) -> a + b
end

@treelike SkipConnection

function (skip::SkipConnection)(input)
  # Apply the layers to the input, then combine the result with the
  # original input through the user-supplied connection.
  skip.connection(skip.layers(input), input)
end

function Base.show(io::IO, b::SkipConnection)
  print(io, "SkipConnection(")
  join(io, b.layers, ", ")
  print(io, ")")
end
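One subtlety of the forward pass above is argument order: the connection receives the block's output first and the original input second, which matters for non-commutative connections such as concatenation. A small illustrative check (sizes assumed, not from the PR):

```julia
# Dense(4, 2) maps 4 rows to 2; concatenating along dims = 1 stacks the
# block's 2 output rows above the 4 original input rows.
sc = SkipConnection(Dense(4, 2), (a, b) -> cat(a, b, dims = 1))
size(sc(randn(4, 3)))  # (6, 3)
```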
@@ -72,4 +72,16 @@ import Flux: activations
      @test length(ps) == 8  # 4 alts, each with weight and bias
    end
  end

  @testset "SkipConnection" begin
    @testset "zero sum" begin
      input = randn(10, 10, 10, 10)
      @test SkipConnection(x -> zeros(size(x)), (a, b) -> a + b)(input) == input
    end

    @testset "concat size" begin
      input = randn(10, 2)
      @test size(SkipConnection(Dense(10, 10), (a, b) -> cat(a, b, dims = 2))(input)) == (10, 4)
    end
  end
end
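The tests above check shapes and values only. As a sketch of how gradient flow through both paths could additionally be verified (Tracker was Flux's AD at the time; this check is an assumption, not part of the PR):

```julia
using Flux
using Flux.Tracker

m = SkipConnection(Dense(10, 10), (a, b) -> a + b)
x = randn(10, 2)
ps = params(m)
gs = Tracker.gradient(() -> sum(m(x)), ps)
gs[first(ps)]  # gradient of the Dense weight, accumulated through both paths
```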