Added outdims for some basic layers
This commit is contained in:
parent
9279d79e63
commit
b4ed16ad9c
@ -39,6 +39,17 @@ function Base.show(io::IO, c::Chain)
|
|||||||
print(io, ")")
|
print(io, ")")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
outdims(c::Chain, isize::Tuple)
|
||||||
|
|
||||||
|
Calculate the output dimensions given the input dimensions, `isize`.
|
||||||
|
|
||||||
|
```julia
|
||||||
|
m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32))
|
||||||
|
outdims(m, (10, 10)) == (6, 6)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
outdims(c::Chain, isize::Tuple) = foldl(∘, map(l -> (x -> outdims(l, x)), c.layers))
|
||||||
|
|
||||||
# This is a temporary and naive implementation
|
# This is a temporary and naive implementation
|
||||||
# it might be replaced in the future for better performance
|
# it might be replaced in the future for better performance
|
||||||
@ -116,6 +127,19 @@ end
|
|||||||
(a::Dense{<:Any,W})(x::AbstractArray{<:AbstractFloat}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
(a::Dense{<:Any,W})(x::AbstractArray{<:AbstractFloat}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||||
a(T.(x))
|
a(T.(x))
|
||||||
|
|
||||||
|
"""
|
||||||
|
outdims(l::Dense, isize)
|
||||||
|
|
||||||
|
Calculate the output dimensions given the input dimensions, `isize`.
|
||||||
|
|
||||||
|
```julia
|
||||||
|
m = Dense(10, 5)
|
||||||
|
outdims(m, (5, 2)) == (5,)
|
||||||
|
outdims(m, (10,)) == (5,)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
outdims(l::Dense, isize) = (size(l.W)[2],)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Diagonal(in::Integer)
|
Diagonal(in::Integer)
|
||||||
|
|
||||||
@ -145,6 +169,17 @@ function Base.show(io::IO, l::Diagonal)
|
|||||||
print(io, "Diagonal(", length(l.α), ")")
|
print(io, "Diagonal(", length(l.α), ")")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
outdims(l::Diagonal, isize)
|
||||||
|
|
||||||
|
Calculate the output dimensions given the input dimensions, `isize`.
|
||||||
|
|
||||||
|
```julia
|
||||||
|
m = Diagonal(10)
|
||||||
|
outdims(m, (10,)) == (10,)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
outdims(l::Diagonal, isize) = (length(l.α),)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Maxout(over)
|
Maxout(over)
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
using NNlib: conv, ∇conv_data, depthwiseconv
|
using NNlib: conv, ∇conv_data, depthwiseconv
|
||||||
|
|
||||||
|
_convoutdims(isize, ksize, ssize, pad) = Int.(floor.((isize .- ksize .+ 2 .* pad) ./ ssize .+ 1))
|
||||||
|
|
||||||
expand(N, i::Tuple) = i
|
expand(N, i::Tuple) = i
|
||||||
expand(N, i::Integer) = ntuple(_ -> i, N)
|
expand(N, i::Integer) = ntuple(_ -> i, N)
|
||||||
"""
|
"""
|
||||||
@ -68,6 +70,18 @@ end
|
|||||||
(a::Conv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
(a::Conv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||||
a(T.(x))
|
a(T.(x))
|
||||||
|
|
||||||
|
"""
|
||||||
|
outdims(l::Conv, isize::Tuple)
|
||||||
|
|
||||||
|
Calculate the output dimensions given the input dimensions, `isize`.
|
||||||
|
|
||||||
|
```julia
|
||||||
|
m = Conv((3, 3), 3 => 16)
|
||||||
|
outdims(m, (10, 10)) == (8, 8)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
outdims(l::Conv{N}, isize) where N = _convoutdims(isize, size(l.weight)[1:N], l.stride, l.pad[1:N])
|
||||||
|
|
||||||
"""
|
"""
|
||||||
ConvTranspose(size, in=>out)
|
ConvTranspose(size, in=>out)
|
||||||
ConvTranspose(size, in=>out, relu)
|
ConvTranspose(size, in=>out, relu)
|
||||||
@ -140,6 +154,7 @@ end
|
|||||||
|
|
||||||
(a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
(a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||||
a(T.(x))
|
a(T.(x))
|
||||||
|
|
||||||
"""
|
"""
|
||||||
DepthwiseConv(size, in=>out)
|
DepthwiseConv(size, in=>out)
|
||||||
DepthwiseConv(size, in=>out, relu)
|
DepthwiseConv(size, in=>out, relu)
|
||||||
@ -204,6 +219,18 @@ end
|
|||||||
(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||||
a(T.(x))
|
a(T.(x))
|
||||||
|
|
||||||
|
"""
|
||||||
|
outdims(l::DepthwiseConv, isize::Tuple)
|
||||||
|
|
||||||
|
Calculate the output dimensions given the input dimensions, `isize`.
|
||||||
|
|
||||||
|
```julia
|
||||||
|
m = DepthwiseConv((3, 3), 3 => 16)
|
||||||
|
outdims(m, (10, 10)) == (8, 8)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
outdims(l::DepthwiseConv{N}, isize) where N = _convoutdims(isize, size(l.weight)[1:N], l.stride, l.pad[1:N])
|
||||||
|
|
||||||
"""
|
"""
|
||||||
CrossCor(size, in=>out)
|
CrossCor(size, in=>out)
|
||||||
CrossCor(size, in=>out, relu)
|
CrossCor(size, in=>out, relu)
|
||||||
@ -304,6 +331,18 @@ function Base.show(io::IO, m::MaxPool)
|
|||||||
print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
|
print(io, "MaxPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
outdims(l::MaxPool, isize::Tuple)
|
||||||
|
|
||||||
|
Calculate the output dimensions given the input dimensions, `isize`.
|
||||||
|
|
||||||
|
```julia
|
||||||
|
m = MaxPool((2, 2))
|
||||||
|
outdims(m, (10, 10)) == (5, 5)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
outdims(l::MaxPool{N}, isize) where N = _convoutdims(isize, l.weight, l.stride, l.pad[1:N])
|
||||||
|
|
||||||
"""
|
"""
|
||||||
MeanPool(k)
|
MeanPool(k)
|
||||||
|
|
||||||
@ -331,3 +370,15 @@ end
|
|||||||
function Base.show(io::IO, m::MeanPool)
|
function Base.show(io::IO, m::MeanPool)
|
||||||
print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
|
print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
outdims(l::MeanPool, isize::Tuple)
|
||||||
|
|
||||||
|
Calculate the output dimensions given the input dimensions, `isize`.
|
||||||
|
|
||||||
|
```julia
|
||||||
|
m = MeanPool((2, 2))
|
||||||
|
outdims(m, (10, 10)) == (5, 5)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
outdims(l::MeanPool{N}, isize) where N = _convoutdims(isize, l.weight, l.stride, l.pad[1:N])
|
Loading…
Reference in New Issue
Block a user