diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl index 4f6a494b..21080914 100644 --- a/src/backend/mxnet/graph.jl +++ b/src/backend/mxnet/graph.jl @@ -40,6 +40,8 @@ graph(::typeof(softmax), xs) = graph(::typeof(cat), dim::Integer, a...) = mx.Concat(a..., dim = dim) graph(::typeof(vcat), a...) = graph(cat, 1, a...) +graph(::typeof(map), f, xss::Tuple...) = map(f, xss...) + graph(::Input, x) = x graph(ctx::Context, d::Affine, x) = diff --git a/src/backend/tensorflow/graph.jl b/src/backend/tensorflow/graph.jl index 20c3c472..e4cae4c0 100644 --- a/src/backend/tensorflow/graph.jl +++ b/src/backend/tensorflow/graph.jl @@ -39,6 +39,8 @@ end graph(::typeof(.-), args...) = -(args...) +graph(::typeof(map), f, xss::Tuple...) = map(f, xss...) + # reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79 batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1])) graph(::typeof(flatten), x) = reshape(x, pack([batchsize(x), Int32(-1)])) diff --git a/test/backend/common.jl b/test/backend/common.jl index 1ed0122c..de4930d8 100644 --- a/test/backend/common.jl +++ b/test/backend/common.jl @@ -55,3 +55,11 @@ function test_stacktrace(bk) @test e.trace[2].func == :TLP end end + +function test_anon(bk) + @testset "Closures" begin + x, y = rand(3), rand(5) + model = bk(@net xs -> map(x -> x .* x, xs)) + @test all(model((x, y)) .≈ (x.*x, y.*y)) + end +end diff --git a/test/backend/mxnet.jl b/test/backend/mxnet.jl index 29e51dd2..c3e39806 100644 --- a/test/backend/mxnet.jl +++ b/test/backend/mxnet.jl @@ -13,6 +13,7 @@ test_tupleio(mxnet) test_recurrence(mxnet) test_stacktrace(mxnet) test_back(mxnet) +test_anon(mxnet) @testset "Native interface" begin f = mx.FeedForward(Chain(d, softmax)) diff --git a/test/backend/tensorflow.jl b/test/backend/tensorflow.jl index ab9d4af7..b727f5e1 100644 --- a/test/backend/tensorflow.jl +++ b/test/backend/tensorflow.jl @@ -12,6 +12,7 @@ dt = tf(d) test_tupleio(tf) test_recurrence(tf) test_stacktrace(tf) +test_anon(tf) @testset "Tensor interface" begin sess = TensorFlow.Session() diff --git a/test/basic.jl b/test/basic.jl index f7d35bd4..37176b89 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -18,6 +18,8 @@ Flux.infer(d, (1, 10)) # @test isa(x, DataFlow.Input) && isa(W, Param) && isa(b, Param) # end +test_anon(identity) + let a1 = Affine(10, 20), a2 = Affine(20, 15) tlp = TLP(a1, a2) @test tlp(xs) ≈ softmax(a2(σ(a1(xs))))