test native tensor support
This commit is contained in:
parent
b79e536c13
commit
42fabadd11
|
@ -67,6 +67,6 @@ function tograph(model, args...)
|
|||
return ctx[:params], ctx[:stacks], out
|
||||
end
|
||||
|
||||
TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[2]
|
||||
TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[3]
|
||||
|
||||
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
|
||||
|
|
|
@ -1,7 +1,19 @@
|
|||
xs = rand(20)
|
||||
|
||||
d = Affine(20, 10)
|
||||
|
||||
dt = tf(d)
|
||||
let dt = tf(d)
|
||||
@test d(xs) ≈ dt(xs)
|
||||
end
|
||||
|
||||
@test d(xs) ≈ dt(xs)
|
||||
# TensorFlow native integration
|
||||
|
||||
using TensorFlow
|
||||
|
||||
let
|
||||
sess = TensorFlow.Session()
|
||||
X = placeholder(Float32)
|
||||
Y = Tensor(d, X)
|
||||
run(sess, initialize_all_variables())
|
||||
|
||||
@test run(sess, Y, Dict(X=>xs')) ≈ d(xs)'
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue