AvgPool shim
This commit is contained in:
parent
5d919175fc
commit
2a57150bce
@ -1,4 +1,4 @@
|
|||||||
export Conv2D, MaxPool, Reshape
|
export Conv2D, MaxPool, AvgPool, Reshape
|
||||||
|
|
||||||
type Conv2D <: Model
|
type Conv2D <: Model
|
||||||
filter::Param{Array{Float64,4}} # [height, width, inchans, outchans]
|
filter::Param{Array{Float64,4}} # [height, width, inchans, outchans]
|
||||||
@ -14,22 +14,26 @@ shape(c::Conv2D, in::Dims{2}) =
|
|||||||
shape(c::Conv2D, in::Dims{3}) =
|
shape(c::Conv2D, in::Dims{3}) =
|
||||||
shape(c, (in[1],in[2]))
|
shape(c, (in[1],in[2]))
|
||||||
|
|
||||||
type MaxPool <: Model
|
for Pool in :[MaxPool, AvgPool].args
|
||||||
size::Dims{2}
|
@eval begin
|
||||||
stride::Dims{2}
|
type $Pool <: Model
|
||||||
|
size::Dims{2}
|
||||||
|
stride::Dims{2}
|
||||||
|
end
|
||||||
|
|
||||||
|
$Pool(size; stride = (1,1)) =
|
||||||
|
$Pool(size, stride)
|
||||||
|
|
||||||
|
shape(c::$Pool, in::Dims{2}) =
|
||||||
|
map(i -> (in[i]-c.size[i])÷c.stride[i]+1, (1,2))
|
||||||
|
|
||||||
|
shape(c::$Pool, in::Dims{3}) =
|
||||||
|
(shape(c, (in[1],in[2]))..., in[3])
|
||||||
|
|
||||||
|
shape(c::$Pool) = nothing
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
MaxPool(size; stride = (1,1)) =
|
|
||||||
MaxPool(size, stride)
|
|
||||||
|
|
||||||
shape(c::MaxPool, in::Dims{2}) =
|
|
||||||
map(i -> (in[i]-c.size[i])÷c.stride[i]+1, (1,2))
|
|
||||||
|
|
||||||
shape(c::MaxPool, in::Dims{3}) =
|
|
||||||
(shape(c, (in[1],in[2]))..., in[3])
|
|
||||||
|
|
||||||
shape(c::MaxPool, in) = throw(ShapeError(c, in))
|
|
||||||
|
|
||||||
immutable Reshape{N}
|
immutable Reshape{N}
|
||||||
dims::Dims{N}
|
dims::Dims{N}
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user