support maps over tuples

This commit is contained in:
Mike J Innes 2017-05-30 17:23:34 +01:00
parent d788cc8c54
commit 790a58be1d
6 changed files with 16 additions and 0 deletions

View File

@ -40,6 +40,8 @@ graph(::typeof(softmax), xs) =
graph(::typeof(cat), dim::Integer, a...) = mx.Concat(a..., dim = dim)
graph(::typeof(vcat), a...) = graph(cat, 1, a...)
graph(::typeof(map), f, xss::Tuple...) = map(f, xss...)
graph(::Input, x) = x
graph(ctx::Context, d::Affine, x) =

View File

@ -39,6 +39,8 @@ end
graph(::typeof(.-), args...) = -(args...)
graph(::typeof(map), f, xss::Tuple...) = map(f, xss...)
# reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79
batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
graph(::typeof(flatten), x) = reshape(x, pack([batchsize(x), Int32(-1)]))

View File

@ -55,3 +55,11 @@ function test_stacktrace(bk)
@test e.trace[2].func == :TLP
end
end
function test_anon(bk)
@testset "Closures" begin
x, y = rand(3), rand(5)
model = bk(@net xs -> map(x -> x .* x, xs))
@test all(model((x, y)) .≈ (x.*x, y.*y))
end
end

View File

@ -13,6 +13,7 @@ test_tupleio(mxnet)
test_recurrence(mxnet)
test_stacktrace(mxnet)
test_back(mxnet)
test_anon(mxnet)
@testset "Native interface" begin
f = mx.FeedForward(Chain(d, softmax))

View File

@ -12,6 +12,7 @@ dt = tf(d)
test_tupleio(tf)
test_recurrence(tf)
test_stacktrace(tf)
test_anon(tf)
@testset "Tensor interface" begin
sess = TensorFlow.Session()

View File

@ -18,6 +18,8 @@ Flux.infer(d, (1, 10))
# @test isa(x, DataFlow.Input) && isa(W, Param) && isa(b, Param)
# end
test_anon(identity)
let a1 = Affine(10, 20), a2 = Affine(20, 15)
tlp = TLP(a1, a2)
@test tlp(xs) softmax(a2(σ(a1(xs))))