reshape layer
This commit is contained in:
parent
438dc9d40a
commit
a56af5d16e
@ -35,7 +35,8 @@ graph(::typeof(tanh), x) = tanh(x)
|
|||||||
|
|
||||||
# reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79
|
# reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79
|
||||||
batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
|
batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
|
||||||
graph(::typeof(flatten), x) = reshape(x, pack([batchsize(x),Int32(-1)]))
|
graph(::typeof(flatten), x) = reshape(x, pack([batchsize(x), Int32(-1)]))
|
||||||
|
graph(r::Reshape, x) = reshape(x, pack([batchsize(x), map(Int32, r.dims)...]))
|
||||||
|
|
||||||
graph(::Input, x) = x
|
graph(::Input, x) = x
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
export Conv2D, MaxPool
|
export Conv2D, MaxPool, Reshape
|
||||||
|
|
||||||
type Conv2D <: Model
|
type Conv2D <: Model
|
||||||
filter::Param{Array{Float32,4}} # [height, width, inchans, outchans]
|
filter::Param{Array{Float32,4}} # [height, width, inchans, outchans]
|
||||||
@ -29,3 +29,16 @@ shape(c::MaxPool, in::Dims{3}) =
|
|||||||
(shape(c, (in[1],in[2]))..., in[3])
|
(shape(c, (in[1],in[2]))..., in[3])
|
||||||
|
|
||||||
shape(c::MaxPool, in) = throw(ShapeError(c, in))
|
shape(c::MaxPool, in) = throw(ShapeError(c, in))
|
||||||
|
|
||||||
|
immutable Reshape{N}
|
||||||
|
dims::Dims{N}
|
||||||
|
end
|
||||||
|
|
||||||
|
Reshape(dims::Integer...) = Reshape(dims)
|
||||||
|
|
||||||
|
function shape(r::Reshape, dims)
|
||||||
|
prod(dims) == prod(r.dims) || throw(ShapeError(r, dims))
|
||||||
|
return r.dims
|
||||||
|
end
|
||||||
|
|
||||||
|
shape(r::Reshape, ::Void) = r.dims
|
||||||
|
Loading…
Reference in New Issue
Block a user