Updated with all basic and conv layers outdims
This commit is contained in:
parent
b4ed16ad9c
commit
31dda0ce6c
@ -40,7 +40,7 @@ function Base.show(io::IO, c::Chain)
|
|||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
outdims(c::Chain, isize::Tuple)
|
outdims(c::Chain, isize)
|
||||||
|
|
||||||
Calculate the output dimensions given the input dimensions, `isize`.
|
Calculate the output dimensions given the input dimensions, `isize`.
|
||||||
|
|
||||||
@ -49,7 +49,7 @@ m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32))
|
|||||||
outdims(m, (10, 10)) == (6, 6)
|
outdims(m, (10, 10)) == (6, 6)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
outdims(c::Chain, isize::Tuple) = foldl(∘, map(l -> (x -> outdims(l, x)), c.layers))
|
outdims(c::Chain, isize) = foldl(∘, map(l -> (x -> outdims(l, x)), c.layers))
|
||||||
|
|
||||||
# This is a temporary and naive implementation
|
# This is a temporary and naive implementation
|
||||||
# it might be replaced in the future for better performance
|
# it might be replaced in the future for better performance
|
||||||
@ -228,6 +228,18 @@ function (mo::Maxout)(input::AbstractArray)
|
|||||||
mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
|
mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
outdims(c::Maxout, isize)
|
||||||
|
|
||||||
|
Calculate the output dimensions given the input dimensions, `isize`.
|
||||||
|
|
||||||
|
```julia
|
||||||
|
m = Maxout(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32))
|
||||||
|
outdims(m, (10, 10)) == (8, 8)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
outdims(l::Maxout, isize) = outdims(first(l.over))
|
||||||
|
|
||||||
"""
|
"""
|
||||||
SkipConnection(layers, connection)
|
SkipConnection(layers, connection)
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
using NNlib: conv, ∇conv_data, depthwiseconv
|
using NNlib: conv, ∇conv_data, depthwiseconv
|
||||||
|
|
||||||
_convoutdims(isize, ksize, ssize, pad) = Int.(floor.((isize .- ksize .+ 2 .* pad) ./ ssize .+ 1))
|
_convoutdims(isize, ksize, ssize, pad) = Int.(floor.((isize .- ksize .+ 2 .* pad) ./ ssize .+ 1))
|
||||||
|
_convtransoutdims(isize, ksize, ssize, pad) = Int.(ssize .* (isize .- 1) .+ ksize .- 2 .* pad))
|
||||||
|
|
||||||
expand(N, i::Tuple) = i
|
expand(N, i::Tuple) = i
|
||||||
expand(N, i::Integer) = ntuple(_ -> i, N)
|
expand(N, i::Integer) = ntuple(_ -> i, N)
|
||||||
@ -155,6 +156,18 @@ end
|
|||||||
(a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
(a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||||
a(T.(x))
|
a(T.(x))
|
||||||
|
|
||||||
|
"""
|
||||||
|
outdims(l::ConvTranspose, isize::Tuple)
|
||||||
|
|
||||||
|
Calculate the output dimensions given the input dimensions, `isize`.
|
||||||
|
|
||||||
|
```julia
|
||||||
|
m = ConvTranspose((3, 3), 3 => 16)
|
||||||
|
outdims(m, (8, 8)) == (10, 10)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
outdims(l::ConvTranspose{N}, isize) where N = _convtransoutdims(isize, size(l.weight)[1:N], l.stride, l.pad[1:N])
|
||||||
|
|
||||||
"""
|
"""
|
||||||
DepthwiseConv(size, in=>out)
|
DepthwiseConv(size, in=>out)
|
||||||
DepthwiseConv(size, in=>out, relu)
|
DepthwiseConv(size, in=>out, relu)
|
||||||
@ -302,6 +315,18 @@ end
|
|||||||
(a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
(a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||||
a(T.(x))
|
a(T.(x))
|
||||||
|
|
||||||
|
"""
|
||||||
|
outdims(l::CrossCor, isize::Tuple)
|
||||||
|
|
||||||
|
Calculate the output dimensions given the input dimensions, `isize`.
|
||||||
|
|
||||||
|
```julia
|
||||||
|
m = CrossCor((3, 3), 3 => 16)
|
||||||
|
outdims(m, (10, 10)) == (8, 8)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
outdims(l::CrossCor{N}, isize) where N = _convoutdims(isize, size(l.weight)[1:N], l.stride, l.pad[1:N])
|
||||||
|
|
||||||
"""
|
"""
|
||||||
MaxPool(k)
|
MaxPool(k)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user