diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl index edc00491..70fc413f 100644 --- a/src/backend/mxnet/graph.jl +++ b/src/backend/mxnet/graph.jl @@ -18,6 +18,7 @@ node(x::Tuple) = map(node, x) node(x::mx.SymbolicNode) = x graph(::typeof(tuple), args...) = (args...,) +graph(::typeof(getindex), t::Tuple, n::Integer) = t[n] graph(::typeof(.+), args...) = mx.broadcast_plus(args...) graph(::typeof(.*), args...) = mx.broadcast_mul(args...) graph(::typeof(.-), args...) = mx.broadcast_sub(args...) diff --git a/src/backend/tensorflow/graph.jl b/src/backend/tensorflow/graph.jl index 3c81e6e7..ef9fdaa7 100644 --- a/src/backend/tensorflow/graph.jl +++ b/src/backend/tensorflow/graph.jl @@ -14,6 +14,7 @@ node(x::Number) = TensorFlow.constant(Float32(x)) graph(::typeof(tuple), args...) = (args...,) graph(s::Split, t::Tuple) = t[s.n] +graph(::typeof(getindex), t::Tuple, n::Integer) = t[n] graph(::typeof(identity), x) = TensorFlow.identity(x) graph(::typeof(softmax), x) = nn.softmax(x) graph(::typeof(relu), x) = nn.relu(x) diff --git a/test/backend/mxnet.jl b/test/backend/mxnet.jl index b9527e2b..a70ec7dd 100644 --- a/test/backend/mxnet.jl +++ b/test/backend/mxnet.jl @@ -13,6 +13,11 @@ m = Multi(20, 15) mm = mxnet(m) @test all(isapprox.(mm(xs, ys), m(xs, ys))) +@testset "Tuple I/O" begin + @test mxnet(@net x -> (x,))([1,2,3]) == ([1,2,3],) + @test mxnet(@net x -> x[1].*x[2])(([1,2,3],[4,5,6])) == [4,10,18] +end + @testset "Recurrence" begin seq = batchone(Seq(rand(10) for i = 1:3)) r = unroll(Recurrent(10, 5), 3) diff --git a/test/backend/tensorflow.jl b/test/backend/tensorflow.jl index 54f5560b..6d7780f9 100644 --- a/test/backend/tensorflow.jl +++ b/test/backend/tensorflow.jl @@ -13,6 +13,11 @@ m = Multi(20, 15) mm = tf(m) @test all(isapprox.(mm(xs, ys), m(xs, ys))) +@testset "Tuple I/O" begin + @test tf(@net x -> (identity(x),))([1,2,3]) == ([1,2,3],) + @test tf(@net x -> x[1].*x[2])(([1,2,3],[4,5,6])) == [4,10,18] +end + @testset "Tensor interface" begin sess = TensorFlow.Session() X = placeholder(Float32)