reshape layer

This commit is contained in:
Mike J Innes 2016-10-10 23:48:16 +01:00
parent 438dc9d40a
commit a56af5d16e
2 changed files with 16 additions and 2 deletions

View File

@ -35,7 +35,8 @@ graph(::typeof(tanh), x) = tanh(x)
# reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79 # reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79
batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1])) batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
graph(::typeof(flatten), x) = reshape(x, pack([batchsize(x),Int32(-1)])) graph(::typeof(flatten), x) = reshape(x, pack([batchsize(x), Int32(-1)]))
graph(r::Reshape, x) = reshape(x, pack([batchsize(x), map(Int32, r.dims)...]))
graph(::Input, x) = x graph(::Input, x) = x

View File

@ -1,4 +1,4 @@
export Conv2D, MaxPool export Conv2D, MaxPool, Reshape
type Conv2D <: Model type Conv2D <: Model
filter::Param{Array{Float32,4}} # [height, width, inchans, outchans] filter::Param{Array{Float32,4}} # [height, width, inchans, outchans]
@ -29,3 +29,16 @@ shape(c::MaxPool, in::Dims{3}) =
(shape(c, (in[1],in[2]))..., in[3]) (shape(c, (in[1],in[2]))..., in[3])
shape(c::MaxPool, in) = throw(ShapeError(c, in)) shape(c::MaxPool, in) = throw(ShapeError(c, in))
immutable Reshape{N}
dims::Dims{N}
end
Reshape(dims::Integer...) = Reshape(dims)
function shape(r::Reshape, dims)
prod(dims) == prod(r.dims) || throw(ShapeError(r, dims))
return r.dims
end
shape(r::Reshape, ::Void) = r.dims