2016-10-10 22:48:16 +00:00
|
|
|
export Conv2D, MaxPool, Reshape
|
2016-09-06 17:03:39 +00:00
|
|
|
|
2016-10-04 21:23:26 +00:00
|
|
|
"""
    Conv2D(size; in = 1, out = 1, stride = (1,1), init = initn)

2-D convolution layer.

`filter` is a `Param` wrapping a 4-d `Float64` array laid out as
`[height, width, inchans, outchans]`; `stride` is the step along each of the
two spatial axes.

The keyword constructor builds the filter with `init(size..., in, out)`, so
`size` is splatted in front of the channel counts and must supply the two
spatial kernel extents, e.g. `(3, 3)`. `init` must return an array compatible
with the `filter` field's element type.
"""
type Conv2D <: Model
  filter::Param{Array{Float64,4}} # [height, width, inchans, outchans]
  stride::Dims{2}
end

# Call the caller-supplied `init` rather than a hard-coded `initn`, so the
# `init` keyword actually takes effect (its default is still `initn`).
Conv2D(size; in = 1, out = 1, stride = (1,1), init = initn) =
  Conv2D(param(init(size..., in, out)), stride)
|
2016-09-06 17:03:39 +00:00
|
|
|
|
2016-10-04 21:23:26 +00:00
|
|
|
# Output shape of the convolution for a 2-D input (no explicit channel axis).
# Each spatial extent becomes (in - kernel) ÷ stride + 1, with no padding term.
# The third element is the filter's output-channel dimension (axis 4).
function shape(c::Conv2D, in::Dims{2})
  kh, kw = size(c.filter, 1), size(c.filter, 2)
  sh, sw = c.stride
  return ((in[1] - kh) ÷ sh + 1, (in[2] - kw) ÷ sw + 1, size(c.filter, 4))
end
|
2016-09-06 17:03:39 +00:00
|
|
|
|
2016-10-04 21:23:26 +00:00
|
|
|
# Output shape of the convolution for a (height, width, channels) input.
# The channel count `in[3]` must equal the filter's input-channel dimension
# (`size(c.filter, 3)`). A mismatch throws `ShapeError` instead of being
# silently dropped; the spatial result comes from the `Dims{2}` method.
function shape(c::Conv2D, in::Dims{3})
  in[3] == size(c.filter, 3) || throw(ShapeError(c, in))
  return shape(c, (in[1], in[2]))
end
|
|
|
|
|
|
|
|
# `MaxPool` layer parameters: the pooling window `size` and the `stride`, each
# given per spatial axis as a 2-tuple of dims.
type MaxPool <: Model
  size::Dims{2}
  stride::Dims{2}
end

# Keyword-style constructor: `MaxPool(window)` or `MaxPool(window, stride = s)`.
# When omitted, `stride` is (1,1).
function MaxPool(size; stride = (1,1))
  return MaxPool(size, stride)
end
|
|
|
|
|
|
|
|
# Spatial output size of pooling a 2-D input: along each axis,
# (in - window) ÷ stride + 1.
function shape(c::MaxPool, in::Dims{2})
  rows = (in[1] - c.size[1]) ÷ c.stride[1] + 1
  cols = (in[2] - c.size[2]) ÷ c.stride[2] + 1
  return (rows, cols)
end
|
|
|
|
|
|
|
|
# Pooling only changes the first two axes; the third axis is carried through
# unchanged.
function shape(c::MaxPool, in::Dims{3})
  spatial = shape(c, (in[1], in[2]))
  return (spatial[1], spatial[2], in[3])
end
|
|
|
|
|
|
|
|
# Fallback for any input that is not a `Dims{2}` or `Dims{3}` (those dispatch
# to the more specific methods): report it as a shape error.
function shape(c::MaxPool, in)
  throw(ShapeError(c, in))
end
|
2016-10-10 22:48:16 +00:00
|
|
|
|
|
|
|
# Reshape layer: its `shape` methods report `dims` as the output shape.
# NOTE(review): unlike `Conv2D` and `MaxPool`, this type does not subtype
# `Model`; confirm that is intended.
immutable Reshape{N}
  dims::Dims{N}
end

# Varargs convenience constructor, e.g. `Reshape(4, 5)`.
# Each argument is converted to `Int` so that non-`Int` integers
# (e.g. `Int32`) still match the `Dims{N}` (= `NTuple{N,Int}`) field.
# Without the conversion, such calls have no matching constructor method.
Reshape(dims::Integer...) = Reshape(map(Int, dims))
|
|
|
|
|
|
|
|
# Output shape of a Reshape: `r.dims`, provided the incoming `dims` holds the
# same total number of elements (product of dims). Otherwise throws `ShapeError`.
function shape(r::Reshape, dims)
  if prod(dims) != prod(r.dims)
    throw(ShapeError(r, dims))
  end
  r.dims
end
|
|
|
|
|
|
|
|
# When the input shape is `nothing`, return the target dims directly; no
# element-count check is possible.
function shape(r::Reshape, ::Void)
  return r.dims
end
|