diff --git a/src/backend/tensorflow/tensorflow.jl b/src/backend/tensorflow/tensorflow.jl
index 153f0ca2..9aabb39a 100644
--- a/src/backend/tensorflow/tensorflow.jl
+++ b/src/backend/tensorflow/tensorflow.jl
@@ -35,7 +35,8 @@ graph(::typeof(tanh), x) = tanh(x)
 
 # reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79
 batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
-graph(::typeof(flatten), x) = reshape(x, pack([batchsize(x),Int32(-1)]))
+graph(::typeof(flatten), x) = reshape(x, pack([batchsize(x), Int32(-1)]))
+graph(r::Reshape, x) = reshape(x, pack([batchsize(x), map(Int32, r.dims)...]))
 
 graph(::Input, x) = x
diff --git a/src/layers/shims.jl b/src/layers/shims.jl
index 6144f3f5..8ffc4e1b 100644
--- a/src/layers/shims.jl
+++ b/src/layers/shims.jl
@@ -1,4 +1,4 @@
-export Conv2D, MaxPool
+export Conv2D, MaxPool, Reshape
 
 type Conv2D <: Model
   filter::Param{Array{Float32,4}} # [height, width, inchans, outchans]
@@ -29,3 +29,16 @@ shape(c::MaxPool, in::Dims{3}) =
   (shape(c, (in[1],in[2]))..., in[3])
 
 shape(c::MaxPool, in) = throw(ShapeError(c, in))
+
+immutable Reshape{N}
+  dims::Dims{N}
+end
+
+Reshape(dims::Integer...) = Reshape(dims)
+
+function shape(r::Reshape, dims)
+  prod(dims) == prod(r.dims) || throw(ShapeError(r, dims))
+  return r.dims
+end
+
+shape(r::Reshape, ::Void) = r.dims
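
Usage sketch (not part of the diff): a minimal illustration of the contract the new shim implements, assuming only the API added above (`Reshape`, `shape`, `ShapeError`); the dimensions are hypothetical example values.

    # Reshape stores the per-example target shape; the batch dimension is
    # handled separately by the backend (see graph(r::Reshape, x) above),
    # which keeps it dynamic via batchsize(x).
    r = Reshape(7, 7, 16)

    shape(r, (28, 28))    # 28*28 == 7*7*16 == 784, so this returns (7, 7, 16)
    shape(r, nothing)     # unknown input shape (::Void): returns (7, 7, 16)
    shape(r, (10, 10))    # 100 != 784: throws ShapeError(r, (10, 10))

Note the design: `shape` only validates that the element count is preserved, while the TensorFlow lowering prepends `batchsize(x)` so the batch dimension never needs to be known statically.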