diff --git a/src/backend/tensorflow/graph.jl b/src/backend/tensorflow/graph.jl index 650a29e6..8f5b47bc 100644 --- a/src/backend/tensorflow/graph.jl +++ b/src/backend/tensorflow/graph.jl @@ -1,6 +1,7 @@ using Base: @get! using DataFlow: Constant, constant, Split using DataFlow.Interpreter +using DataFlow.Interpreter: stack using Flux: imap using TensorFlow: RawTensor, TFException @@ -19,7 +20,7 @@ graph(::typeof(σ), x) = nn.sigmoid(x) graph(::typeof(hcat), xs...) = concat(1, xs) graph(::typeof(seq), xs, n) = TensorFlow.unpack(xs, num = n, axis = 1) -for op in (tanh, *, .*, .+, .-) +for op in (tanh, *, .*, .+) @eval graph(::typeof($op), args...) = $op(args...) end