removed Flatten struct

updated documentation
This commit is contained in:
Garben Tanghe 2020-02-27 12:44:17 +01:00
parent 82e16a5b29
commit 746e3310f1
4 changed files with 2 additions and 30 deletions

View File

@ -20,7 +20,7 @@ GlobalMeanPool
DepthwiseConv
ConvTranspose
CrossCor
Flatten
flatten
```
## Recurrent Layers

View File

@ -11,7 +11,7 @@ using Zygote: Params, @adjoint, gradient, pullback, @nograd
export gradient
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose,
GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool, Flatten,
GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool, flatten,
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!, trainmode!

View File

@ -416,29 +416,3 @@ function Base.show(io::IO, m::MeanPool)
end
outdims(l::MeanPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad))
"""
Flatten()
Flattening layer.
Transforms (w,h,c,b)-shaped input into (w*h*c,b)-shaped output,
by linearizing all values for each element in the batch.
"""
struct Flatten{F}
σ::F
function Flatten(σ::F = identity) where {F}
return new{F}(σ)
end
end
function (f::Flatten)(x::AbstractArray)
σ = f.σ
σ(flatten(x))
end
function Base.show(io::IO, f::Flatten)
print(io, "Flatten(")
f.σ == identity || print(io, f.σ)
print(io, ")")
end

View File

@ -12,8 +12,6 @@ using Flux: gradient
@test mp(x) == maxpool(x, PoolDims(x, 2))
mp = MeanPool((2, 2))
@test mp(x) == meanpool(x, PoolDims(x, 2))
f = Flatten()
@test size(f(x)) == (300, 2)
end
@testset "CNN" begin