test native tensor support

This commit is contained in:
Mike J Innes 2017-01-27 00:02:59 +05:30
parent b79e536c13
commit 42fabadd11
2 changed files with 16 additions and 4 deletions

View File

@ -67,6 +67,6 @@ function tograph(model, args...)
return ctx[:params], ctx[:stacks], out
end
TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[2]
TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[3]
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))

View File

@ -1,7 +1,19 @@
xs = rand(20)
d = Affine(20, 10)
dt = tf(d)
let dt = tf(d)
@test d(xs) dt(xs)
end
@test d(xs) dt(xs)
# TensorFlow native integration
using TensorFlow
let
sess = TensorFlow.Session()
X = placeholder(Float32)
Y = Tensor(d, X)
run(sess, initialize_all_variables())
@test run(sess, Y, Dict(X=>xs')) d(xs)'
end