diff --git a/src/backend/tensorflow/graph.jl b/src/backend/tensorflow/graph.jl index 47113643..291da771 100644 --- a/src/backend/tensorflow/graph.jl +++ b/src/backend/tensorflow/graph.jl @@ -36,6 +36,7 @@ graph(::typeof(Flux.tile), args...) = TensorFlow.tile(args...) graph(::typeof(fill), x, dims) = Ops.fill(convert(Tensor{Int32}, dims), Tensor(x)) graph(::typeof(Flux.cast), args...) = TensorFlow.cast(args...) graph(::typeof(solve), A, b) = TensorFlow.matrix_solve(A, b) +graph(::typeof(triangular_solve), A, b) = TensorFlow.matrix_triangular_solve(A, b; lower=false) for op in (*, .*, .+, .^, log, exp, ceil, floor, sqrt, abs, cos, sin, tan, atan, asin, acos, tanh, lgamma, erf, erfc, real, imag, conj, diff --git a/src/ops.jl b/src/ops.jl index 8abc6c41..4d57e565 100644 --- a/src/ops.jl +++ b/src/ops.jl @@ -1,4 +1,4 @@ -export reshape, tile, fill, cast, solve +export reshape, tile, fill, cast, solve, triangular_solve import Base: reshape, fill @@ -7,3 +7,4 @@ tile(x::AbstractArray, mult::AbstractArray) = repeat(x,outer=tuple(mult...)) fill{T}(x::T, dims::AbstractArray) = fill(x,tuple(dims...)) cast{T}(x::AbstractArray, ::Type{T}) = convert(Array{T},x) solve(A::AbstractArray, b::AbstractArray) = A\b +triangular_solve(A::AbstractArray, b::AbstractArray) = A\b diff --git a/test/backend/tensorflow.jl b/test/backend/tensorflow.jl index 3d9f56b0..4944cd7f 100644 --- a/test/backend/tensorflow.jl +++ b/test/backend/tensorflow.jl @@ -56,6 +56,8 @@ end A = randn(Float32,(5,5)) b = randn(Float32,(5,1)) @test tf(@net (x,y) -> solve(x,y))(A,b) ≈ A\b + _,A,_ = lu(A) + @test tf(@net (x,y) -> triangular_solve(x,y))(A,b) ≈ A\b end end