removed Flatten struct

updated documentation
This commit is contained in:
Garben Tanghe 2020-02-27 12:44:17 +01:00
parent 82e16a5b29
commit 746e3310f1
4 changed files with 2 additions and 30 deletions

View File

@ -20,7 +20,7 @@ GlobalMeanPool
DepthwiseConv DepthwiseConv
ConvTranspose ConvTranspose
CrossCor CrossCor
Flatten flatten
``` ```
## Recurrent Layers ## Recurrent Layers

View File

@ -11,7 +11,7 @@ using Zygote: Params, @adjoint, gradient, pullback, @nograd
export gradient export gradient
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose,
GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool, Flatten, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool, flatten,
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm, DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!, trainmode! SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!, trainmode!

View File

@ -416,29 +416,3 @@ function Base.show(io::IO, m::MeanPool)
end end
outdims(l::MeanPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad)) outdims(l::MeanPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad))
"""
Flatten()
Flattening layer.
Transforms (w,h,c,b)-shaped input into (w*h*c,b)-shaped output,
by linearizing all values for each element in the batch.
"""
struct Flatten{F}
σ::F
function Flatten(σ::F = identity) where {F}
return new{F}(σ)
end
end
function (f::Flatten)(x::AbstractArray)
σ = f.σ
σ(flatten(x))
end
function Base.show(io::IO, f::Flatten)
print(io, "Flatten(")
f.σ == identity || print(io, f.σ)
print(io, ")")
end

View File

@ -12,8 +12,6 @@ using Flux: gradient
@test mp(x) == maxpool(x, PoolDims(x, 2)) @test mp(x) == maxpool(x, PoolDims(x, 2))
mp = MeanPool((2, 2)) mp = MeanPool((2, 2))
@test mp(x) == meanpool(x, PoolDims(x, 2)) @test mp(x) == meanpool(x, PoolDims(x, 2))
f = Flatten()
@test size(f(x)) == (300, 2)
end end
@testset "CNN" begin @testset "CNN" begin