diff --git a/src/backend/tensorflow/graph.jl b/src/backend/tensorflow/graph.jl index 60e3f078..febee8fb 100644 --- a/src/backend/tensorflow/graph.jl +++ b/src/backend/tensorflow/graph.jl @@ -13,6 +13,8 @@ graph(s::Split, t::Tuple) = t[s.n] graph(::typeof(softmax), x) = nn.softmax(x) graph(::typeof(relu), x) = nn.relu(x) graph(::typeof(σ), x) = nn.sigmoid(x) +graph(::typeof(hcat), xs...) = concat(1, xs) +graph(::typeof(seq), xs, n) = TensorFlow.unpack(xs, num = n, axis = 1) graph(::typeof(.+), args...) = +(args...) for op in (tanh, *, .*, +, -, .-)