reshape layer
This commit is contained in:
parent
438dc9d40a
commit
a56af5d16e
@ -36,6 +36,7 @@ graph(::typeof(tanh), x) = tanh(x)
|
||||
# reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79
|
||||
batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
|
||||
graph(::typeof(flatten), x) = reshape(x, pack([batchsize(x), Int32(-1)]))
|
||||
graph(r::Reshape, x) = reshape(x, pack([batchsize(x), map(Int32, r.dims)...]))
|
||||
|
||||
graph(::Input, x) = x
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
export Conv2D, MaxPool
|
||||
export Conv2D, MaxPool, Reshape
|
||||
|
||||
type Conv2D <: Model
|
||||
filter::Param{Array{Float32,4}} # [height, width, inchans, outchans]
|
||||
@ -29,3 +29,16 @@ shape(c::MaxPool, in::Dims{3}) =
|
||||
(shape(c, (in[1],in[2]))..., in[3])
|
||||
|
||||
shape(c::MaxPool, in) = throw(ShapeError(c, in))
|
||||
|
||||
immutable Reshape{N}
|
||||
dims::Dims{N}
|
||||
end
|
||||
|
||||
Reshape(dims::Integer...) = Reshape(dims)
|
||||
|
||||
function shape(r::Reshape, dims)
|
||||
prod(dims) == prod(r.dims) || throw(ShapeError(r, dims))
|
||||
return r.dims
|
||||
end
|
||||
|
||||
shape(r::Reshape, ::Void) = r.dims
|
||||
|
Loading…
Reference in New Issue
Block a user