This commit is contained in:
Mike J Innes 2016-12-15 20:53:15 +00:00
parent 4b64bf11a5
commit 03840d043c
1 changed files with 5 additions and 2 deletions

View File

@ -7,6 +7,7 @@ using TensorFlow: RawTensor
node(x::Tuple) = map(node, x)
node(x::Tensor) = x
node(x::Variable) = x
node(x::Number) = TensorFlow.constant(Float32(x))
graph(::typeof(tuple), args...) = (args...,)
@ -18,10 +19,12 @@ graph(::typeof(hcat), xs...) = concat(1, xs)
graph(::typeof(seq), xs, n) = TensorFlow.unpack(xs, num = n, axis = 1)
graph(::typeof(.+), args...) = +(args...)
for op in (tanh, *, .*, +, -, .-)
@eval graph(::typeof($op), args...) = $op(args...)
for op in (tanh, *, .*, +, -)
@eval graph(::typeof($op), args...) = $op(node(args)...)
end
graph(::typeof(.-), args...) = -(node(args)...)
# reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79
batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
graph(::typeof(flatten), x) = reshape(x, pack([batchsize(x), Int32(-1)]))