split up Flatten layer to use the flatten function

This commit is contained in:
Garben Tanghe 2019-12-05 14:16:12 +01:00
parent 3e14bd878c
commit 82e16a5b29
3 changed files with 38 additions and 13 deletions

View File

@ -425,12 +425,20 @@ Flattening layer.
Transforms (w,h,c,b)-shaped input into (w*h*c,b)-shaped output,
by linearizing all values for each element in the batch.
"""
struct Flatten end
struct Flatten{F}
σ::F
function Flatten(σ::F = identity) where {F}
return new{F}(σ)
end
end
function (f::Flatten)(x)
return reshape(x, :, size(x)[end])
function (f::Flatten)(x::AbstractArray)
σ = f.σ
σ(flatten(x))
end
function Base.show(io::IO, f::Flatten)
print(io, "Flatten()")
print(io, "Flatten(")
f.σ == identity || print(io, f.σ)
print(io, ")")
end

View File

@ -200,3 +200,13 @@ Returns `1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)
[Tversky loss function for image segmentation using 3D fully convolutional deep networks](https://arxiv.org/pdf/1706.05721.pdf)
"""
tversky_loss(, y; β=eltype()(0.7)) = 1 - (sum(y .* ) + 1) / (sum(y .* + β*(1 .- y) .* + (1 - β)*y .* (1 .- )) + 1)
"""
flatten(x::AbstractArray)
Transforms (w,h,c,b)-shaped input into (w*h*c,b)-shaped output,
by linearizing all values for each element in the batch.
"""
function flatten(x::AbstractArray)
return reshape(x, :, size(x)[end])
end

View File

@ -1,6 +1,6 @@
using Test
using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
σ, binarycrossentropy, logitbinarycrossentropy
σ, binarycrossentropy, logitbinarycrossentropy, flatten
const ϵ = 1e-7
@ -116,3 +116,10 @@ const ϵ = 1e-7
end
end
end
@testset "helpers" begin
@testset "flatten" begin
x = randn(Float32, 10, 10, 3, 2)
@test size(flatten(x)) == (300, 2)
end
end