diff --git a/src/backend/tensorflow/graph.jl b/src/backend/tensorflow/graph.jl index 49daf657..63c4f4e2 100644 --- a/src/backend/tensorflow/graph.jl +++ b/src/backend/tensorflow/graph.jl @@ -31,7 +31,7 @@ graph(::typeof(svd), x) = svd(x) for op in (*, .*, .+, .^, log, exp, ceil, floor, sqrt, abs, cos, sin, tan, atan, asin, acos, tanh, lgamma, erf, erfc, real, imag, conj, - inv) + inv, det) @eval graph(::typeof($op), args...) = $op(args...) end diff --git a/test/backend/tensorflow.jl b/test/backend/tensorflow.jl index 6fdef6cd..61f700f6 100644 --- a/test/backend/tensorflow.jl +++ b/test/backend/tensorflow.jl @@ -41,6 +41,12 @@ u,s,v = m(A) m = tf(f) @test maximum(abs.(m(A)-inv(A))) < error_margin +#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +# det +@net f(x) = det(x) +m = tf(f) +@test maximum(abs.(m(A)-det(A))) < error_margin + #%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% end