AvgPool shim

This commit is contained in:
Mike J Innes 2017-03-06 17:21:35 +00:00
parent 5d919175fc
commit 2a57150bce

View File

@ -1,4 +1,4 @@
export Conv2D, MaxPool, Reshape export Conv2D, MaxPool, AvgPool, Reshape
type Conv2D <: Model type Conv2D <: Model
filter::Param{Array{Float64,4}} # [height, width, inchans, outchans] filter::Param{Array{Float64,4}} # [height, width, inchans, outchans]
@ -14,22 +14,26 @@ shape(c::Conv2D, in::Dims{2}) =
shape(c::Conv2D, in::Dims{3}) = shape(c::Conv2D, in::Dims{3}) =
shape(c, (in[1],in[2])) shape(c, (in[1],in[2]))
type MaxPool <: Model for Pool in :[MaxPool, AvgPool].args
size::Dims{2} @eval begin
stride::Dims{2} type $Pool <: Model
size::Dims{2}
stride::Dims{2}
end
$Pool(size; stride = (1,1)) =
$Pool(size, stride)
shape(c::$Pool, in::Dims{2}) =
map(i -> (in[i]-c.size[i])÷c.stride[i]+1, (1,2))
shape(c::$Pool, in::Dims{3}) =
(shape(c, (in[1],in[2]))..., in[3])
shape(c::$Pool) = nothing
end
end end
MaxPool(size; stride = (1,1)) =
MaxPool(size, stride)
shape(c::MaxPool, in::Dims{2}) =
map(i -> (in[i]-c.size[i])÷c.stride[i]+1, (1,2))
shape(c::MaxPool, in::Dims{3}) =
(shape(c, (in[1],in[2]))..., in[3])
shape(c::MaxPool, in) = throw(ShapeError(c, in))
immutable Reshape{N} immutable Reshape{N}
dims::Dims{N} dims::Dims{N}
end end