fix ops
This commit is contained in:
parent
4b64bf11a5
commit
03840d043c
|
@ -7,6 +7,7 @@ using TensorFlow: RawTensor
|
|||
|
||||
node(x::Tuple) = map(node, x)
|
||||
node(x::Tensor) = x
|
||||
node(x::Variable) = x
|
||||
node(x::Number) = TensorFlow.constant(Float32(x))
|
||||
|
||||
graph(::typeof(tuple), args...) = (args...,)
|
||||
|
@ -18,10 +19,12 @@ graph(::typeof(hcat), xs...) = concat(1, xs)
|
|||
graph(::typeof(seq), xs, n) = TensorFlow.unpack(xs, num = n, axis = 1)
|
||||
graph(::typeof(.+), args...) = +(args...)
|
||||
|
||||
for op in (tanh, *, .*, +, -, .-)
|
||||
@eval graph(::typeof($op), args...) = $op(args...)
|
||||
for op in (tanh, *, .*, +, -)
|
||||
@eval graph(::typeof($op), args...) = $op(node(args)...)
|
||||
end
|
||||
|
||||
graph(::typeof(.-), args...) = -(node(args)...)
|
||||
|
||||
# reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79
|
||||
batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
|
||||
graph(::typeof(flatten), x) = reshape(x, pack([batchsize(x), Int32(-1)]))
|
||||
|
|
Loading…
Reference in New Issue