From 746e3310f18485c0c30c9975c71c88d53d00fe26 Mon Sep 17 00:00:00 2001 From: Garben Tanghe Date: Thu, 27 Feb 2020 12:44:17 +0100 Subject: [PATCH] removed Flatten struct updated documentation --- docs/src/models/layers.md | 2 +- src/Flux.jl | 2 +- src/layers/conv.jl | 26 -------------------------- test/layers/conv.jl | 2 -- 4 files changed, 2 insertions(+), 30 deletions(-) diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index 5f12d41a..2b5c1591 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -20,7 +20,7 @@ GlobalMeanPool DepthwiseConv ConvTranspose CrossCor -Flatten +flatten ``` ## Recurrent Layers diff --git a/src/Flux.jl b/src/Flux.jl index 725abfa7..f973dc4c 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -11,7 +11,7 @@ using Zygote: Params, @adjoint, gradient, pullback, @nograd export gradient export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, - GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool, Flatten, + GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool, flatten, DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm, SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!, trainmode! diff --git a/src/layers/conv.jl b/src/layers/conv.jl index faca0895..742091a6 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -416,29 +416,3 @@ function Base.show(io::IO, m::MeanPool) end outdims(l::MeanPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad)) - -""" - Flatten() - -Flattening layer. - -Transforms (w,h,c,b)-shaped input into (w*h*c,b)-shaped output, -by linearizing all values for each element in the batch. -""" -struct Flatten{F} - σ::F - function Flatten(σ::F = identity) where {F} - return new{F}(σ) - end -end - -function (f::Flatten)(x::AbstractArray) - σ = f.σ - σ(flatten(x)) -end - -function Base.show(io::IO, f::Flatten) - print(io, "Flatten(") - f.σ == identity || print(io, f.σ) - print(io, ")") -end diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 60e1898d..e7b3963d 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -12,8 +12,6 @@ using Flux: gradient @test mp(x) == maxpool(x, PoolDims(x, 2)) mp = MeanPool((2, 2)) @test mp(x) == meanpool(x, PoolDims(x, 2)) - f = Flatten() - @test size(f(x)) == (300, 2) end @testset "CNN" begin