tuple i/o tests

This commit is contained in:
Mike J Innes 2017-05-01 17:44:20 +01:00
parent 2467ca4187
commit 3998be2244
4 changed files with 12 additions and 0 deletions

View File

@ -18,6 +18,7 @@ node(x::Tuple) = map(node, x)
node(x::mx.SymbolicNode) = x
graph(::typeof(tuple), args...) = (args...,)
graph(::typeof(getindex), t::Tuple, n::Integer) = t[n]
graph(::typeof(.+), args...) = mx.broadcast_plus(args...)
graph(::typeof(.*), args...) = mx.broadcast_mul(args...)
graph(::typeof(.-), args...) = mx.broadcast_sub(args...)

View File

@ -14,6 +14,7 @@ node(x::Number) = TensorFlow.constant(Float32(x))
graph(::typeof(tuple), args...) = (args...,)
graph(s::Split, t::Tuple) = t[s.n]
graph(::typeof(getindex), t::Tuple, n::Integer) = t[n]
graph(::typeof(identity), x) = TensorFlow.identity(x)
graph(::typeof(softmax), x) = nn.softmax(x)
graph(::typeof(relu), x) = nn.relu(x)

View File

@ -13,6 +13,11 @@ m = Multi(20, 15)
mm = mxnet(m)
@test all(isapprox.(mm(xs, ys), m(xs, ys)))
@testset "Tuple I/O" begin
@test mxnet(@net x -> (x,))([1,2,3]) == ([1,2,3],)
@test mxnet(@net x -> x[1].*x[2])(([1,2,3],[4,5,6])) == [4,10,18]
end
@testset "Recurrence" begin
seq = batchone(Seq(rand(10) for i = 1:3))
r = unroll(Recurrent(10, 5), 3)

View File

@ -13,6 +13,11 @@ m = Multi(20, 15)
mm = tf(m)
@test all(isapprox.(mm(xs, ys), m(xs, ys)))
@testset "Tuple I/O" begin
@test tf(@net x -> (identity(x),))([1,2,3]) == ([1,2,3],)
@test tf(@net x -> x[1].*x[2])(([1,2,3],[4,5,6])) == [4,10,18]
end
@testset "Tensor interface" begin
sess = TensorFlow.Session()
X = placeholder(Float32)