split up Flatten layer to use the flatten function
This commit is contained in:
parent
3e14bd878c
commit
82e16a5b29
@ -425,12 +425,20 @@ Flattening layer.
|
|||||||
Transforms (w,h,c,b)-shaped input into (w*h*c,b)-shaped output,
|
Transforms (w,h,c,b)-shaped input into (w*h*c,b)-shaped output,
|
||||||
by linearizing all values for each element in the batch.
|
by linearizing all values for each element in the batch.
|
||||||
"""
|
"""
|
||||||
struct Flatten end
|
struct Flatten{F}
|
||||||
|
σ::F
|
||||||
|
function Flatten(σ::F = identity) where {F}
|
||||||
|
return new{F}(σ)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
function (f::Flatten)(x)
|
function (f::Flatten)(x::AbstractArray)
|
||||||
return reshape(x, :, size(x)[end])
|
σ = f.σ
|
||||||
|
σ(flatten(x))
|
||||||
end
|
end
|
||||||
|
|
||||||
function Base.show(io::IO, f::Flatten)
|
function Base.show(io::IO, f::Flatten)
|
||||||
print(io, "Flatten()")
|
print(io, "Flatten(")
|
||||||
|
f.σ == identity || print(io, f.σ)
|
||||||
|
print(io, ")")
|
||||||
end
|
end
|
||||||
|
@ -200,3 +200,13 @@ Returns `1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)
|
|||||||
[Tversky loss function for image segmentation using 3D fully convolutional deep networks](https://arxiv.org/pdf/1706.05721.pdf)
|
[Tversky loss function for image segmentation using 3D fully convolutional deep networks](https://arxiv.org/pdf/1706.05721.pdf)
|
||||||
"""
|
"""
|
||||||
tversky_loss(ŷ, y; β=eltype(ŷ)(0.7)) = 1 - (sum(y .* ŷ) + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)
|
tversky_loss(ŷ, y; β=eltype(ŷ)(0.7)) = 1 - (sum(y .* ŷ) + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)
|
||||||
|
|
||||||
|
"""
|
||||||
|
flatten(x::AbstractArray)
|
||||||
|
|
||||||
|
Transforms (w,h,c,b)-shaped input into (w*h*c,b)-shaped output,
|
||||||
|
by linearizing all values for each element in the batch.
|
||||||
|
"""
|
||||||
|
function flatten(x::AbstractArray)
|
||||||
|
return reshape(x, :, size(x)[end])
|
||||||
|
end
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
using Test
|
using Test
|
||||||
using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
|
using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
|
||||||
σ, binarycrossentropy, logitbinarycrossentropy
|
σ, binarycrossentropy, logitbinarycrossentropy, flatten
|
||||||
|
|
||||||
const ϵ = 1e-7
|
const ϵ = 1e-7
|
||||||
|
|
||||||
@ -116,3 +116,10 @@ const ϵ = 1e-7
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@testset "helpers" begin
|
||||||
|
@testset "flatten" begin
|
||||||
|
x = randn(Float32, 10, 10, 3, 2)
|
||||||
|
@test size(flatten(x)) == (300, 2)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user