fix conv2d shape inference
This commit is contained in:
parent
4961bf72af
commit
438dc9d40a
@ -1,7 +1,7 @@
|
|||||||
export Conv2D, MaxPool
|
export Conv2D, MaxPool
|
||||||
|
|
||||||
type Conv2D <: Model
|
type Conv2D <: Model
|
||||||
filter::Param{Array{Float32,4}} # [height, width, outchans, inchans]
|
filter::Param{Array{Float32,4}} # [height, width, inchans, outchans]
|
||||||
stride::Dims{2}
|
stride::Dims{2}
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -9,7 +9,7 @@ Conv2D(size; in = 1, out = 1, stride = (1,1), init = initn) =
|
|||||||
Conv2D(param(initn(size..., in, out)), stride)
|
Conv2D(param(initn(size..., in, out)), stride)
|
||||||
|
|
||||||
shape(c::Conv2D, in::Dims{2}) =
|
shape(c::Conv2D, in::Dims{2}) =
|
||||||
(map(i -> (in[i]-size(c.filter,i))÷c.stride[i]+1, (1,2))..., size(c.filter, 3))
|
(map(i -> (in[i]-size(c.filter,i))÷c.stride[i]+1, (1,2))..., size(c.filter, 4))
|
||||||
|
|
||||||
shape(c::Conv2D, in::Dims{3}) =
|
shape(c::Conv2D, in::Dims{3}) =
|
||||||
shape(c, (in[1],in[2]))
|
shape(c, (in[1],in[2]))
|
||||||
|
Loading…
Reference in New Issue
Block a user