diff --git a/src/layers/shims.jl b/src/layers/shims.jl index bad409b9..c227278b 100644 --- a/src/layers/shims.jl +++ b/src/layers/shims.jl @@ -1,4 +1,4 @@ -export Conv2D, MaxPool, Reshape +export Conv2D, MaxPool, AvgPool, Reshape type Conv2D <: Model filter::Param{Array{Float64,4}} # [height, width, inchans, outchans] @@ -14,22 +14,26 @@ shape(c::Conv2D, in::Dims{2}) = shape(c::Conv2D, in::Dims{3}) = shape(c, (in[1],in[2])) -type MaxPool <: Model - size::Dims{2} - stride::Dims{2} +for Pool in :[MaxPool, AvgPool].args + @eval begin + type $Pool <: Model + size::Dims{2} + stride::Dims{2} + end + + $Pool(size; stride = (1,1)) = + $Pool(size, stride) + + shape(c::$Pool, in::Dims{2}) = + map(i -> (in[i]-c.size[i])÷c.stride[i]+1, (1,2)) + + shape(c::$Pool, in::Dims{3}) = + (shape(c, (in[1],in[2]))..., in[3]) + + shape(c::$Pool) = nothing + end end -MaxPool(size; stride = (1,1)) = - MaxPool(size, stride) - -shape(c::MaxPool, in::Dims{2}) = - map(i -> (in[i]-c.size[i])÷c.stride[i]+1, (1,2)) - -shape(c::MaxPool, in::Dims{3}) = - (shape(c, (in[1],in[2]))..., in[3]) - -shape(c::MaxPool, in) = throw(ShapeError(c, in)) - immutable Reshape{N} dims::Dims{N} end