tuple i/o tests
This commit is contained in:
parent
2467ca4187
commit
3998be2244
@ -18,6 +18,7 @@ node(x::Tuple) = map(node, x)
|
||||
node(x::mx.SymbolicNode) = x
|
||||
|
||||
graph(::typeof(tuple), args...) = (args...,)
|
||||
graph(::typeof(getindex), t::Tuple, n::Integer) = t[n]
|
||||
graph(::typeof(.+), args...) = mx.broadcast_plus(args...)
|
||||
graph(::typeof(.*), args...) = mx.broadcast_mul(args...)
|
||||
graph(::typeof(.-), args...) = mx.broadcast_sub(args...)
|
||||
|
@ -14,6 +14,7 @@ node(x::Number) = TensorFlow.constant(Float32(x))
|
||||
|
||||
graph(::typeof(tuple), args...) = (args...,)
|
||||
graph(s::Split, t::Tuple) = t[s.n]
|
||||
graph(::typeof(getindex), t::Tuple, n::Integer) = t[n]
|
||||
graph(::typeof(identity), x) = TensorFlow.identity(x)
|
||||
graph(::typeof(softmax), x) = nn.softmax(x)
|
||||
graph(::typeof(relu), x) = nn.relu(x)
|
||||
|
@ -13,6 +13,11 @@ m = Multi(20, 15)
|
||||
mm = mxnet(m)
|
||||
@test all(isapprox.(mm(xs, ys), m(xs, ys)))
|
||||
|
||||
@testset "Tuple I/O" begin
|
||||
@test mxnet(@net x -> (x,))([1,2,3]) == ([1,2,3],)
|
||||
@test mxnet(@net x -> x[1].*x[2])(([1,2,3],[4,5,6])) == [4,10,18]
|
||||
end
|
||||
|
||||
@testset "Recurrence" begin
|
||||
seq = batchone(Seq(rand(10) for i = 1:3))
|
||||
r = unroll(Recurrent(10, 5), 3)
|
||||
|
@ -13,6 +13,11 @@ m = Multi(20, 15)
|
||||
mm = tf(m)
|
||||
@test all(isapprox.(mm(xs, ys), m(xs, ys)))
|
||||
|
||||
@testset "Tuple I/O" begin
|
||||
@test tf(@net x -> (identity(x),))([1,2,3]) == ([1,2,3],)
|
||||
@test tf(@net x -> x[1].*x[2])(([1,2,3],[4,5,6])) == [4,10,18]
|
||||
end
|
||||
|
||||
@testset "Tensor interface" begin
|
||||
sess = TensorFlow.Session()
|
||||
X = placeholder(Float32)
|
||||
|
Loading…
Reference in New Issue
Block a user