parent
82e16a5b29
commit
746e3310f1
|
@ -20,7 +20,7 @@ GlobalMeanPool
|
||||||
DepthwiseConv
|
DepthwiseConv
|
||||||
ConvTranspose
|
ConvTranspose
|
||||||
CrossCor
|
CrossCor
|
||||||
Flatten
|
flatten
|
||||||
```
|
```
|
||||||
|
|
||||||
## Recurrent Layers
|
## Recurrent Layers
|
||||||
|
|
|
@ -11,7 +11,7 @@ using Zygote: Params, @adjoint, gradient, pullback, @nograd
|
||||||
export gradient
|
export gradient
|
||||||
|
|
||||||
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose,
|
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose,
|
||||||
GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool, Flatten,
|
GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool, flatten,
|
||||||
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
|
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
|
||||||
SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!, trainmode!
|
SkipConnection, params, fmap, cpu, gpu, f32, f64, testmode!, trainmode!
|
||||||
|
|
||||||
|
|
|
@ -416,29 +416,3 @@ function Base.show(io::IO, m::MeanPool)
|
||||||
end
|
end
|
||||||
|
|
||||||
outdims(l::MeanPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad))
|
outdims(l::MeanPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad))
|
||||||
|
|
||||||
"""
|
|
||||||
Flatten()
|
|
||||||
|
|
||||||
Flattening layer.
|
|
||||||
|
|
||||||
Transforms (w,h,c,b)-shaped input into (w*h*c,b)-shaped output,
|
|
||||||
by linearizing all values for each element in the batch.
|
|
||||||
"""
|
|
||||||
struct Flatten{F}
|
|
||||||
σ::F
|
|
||||||
function Flatten(σ::F = identity) where {F}
|
|
||||||
return new{F}(σ)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function (f::Flatten)(x::AbstractArray)
|
|
||||||
σ = f.σ
|
|
||||||
σ(flatten(x))
|
|
||||||
end
|
|
||||||
|
|
||||||
function Base.show(io::IO, f::Flatten)
|
|
||||||
print(io, "Flatten(")
|
|
||||||
f.σ == identity || print(io, f.σ)
|
|
||||||
print(io, ")")
|
|
||||||
end
|
|
||||||
|
|
|
@ -12,8 +12,6 @@ using Flux: gradient
|
||||||
@test mp(x) == maxpool(x, PoolDims(x, 2))
|
@test mp(x) == maxpool(x, PoolDims(x, 2))
|
||||||
mp = MeanPool((2, 2))
|
mp = MeanPool((2, 2))
|
||||||
@test mp(x) == meanpool(x, PoolDims(x, 2))
|
@test mp(x) == meanpool(x, PoolDims(x, 2))
|
||||||
f = Flatten()
|
|
||||||
@test size(f(x)) == (300, 2)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
@testset "CNN" begin
|
@testset "CNN" begin
|
||||||
|
|
Loading…
Reference in New Issue