diff --git a/src/backend/tensorflow/graph.jl b/src/backend/tensorflow/graph.jl index 2b8f24bd..47113643 100644 --- a/src/backend/tensorflow/graph.jl +++ b/src/backend/tensorflow/graph.jl @@ -35,6 +35,7 @@ graph(::typeof(reshape), x, dims) = TensorFlow.reshape(x,convert(Tensor{Int32},d graph(::typeof(Flux.tile), args...) = TensorFlow.tile(args...) graph(::typeof(fill), x, dims) = Ops.fill(convert(Tensor{Int32}, dims), Tensor(x)) graph(::typeof(Flux.cast), args...) = TensorFlow.cast(args...) +graph(::typeof(solve), A, b) = TensorFlow.matrix_solve(A, b) for op in (*, .*, .+, .^, log, exp, ceil, floor, sqrt, abs, cos, sin, tan, atan, asin, acos, tanh, lgamma, erf, erfc, real, imag, conj, diff --git a/src/ops.jl b/src/ops.jl index a4c6c4e7..8abc6c41 100644 --- a/src/ops.jl +++ b/src/ops.jl @@ -1,4 +1,4 @@ -export reshape, tile, fill, cast +export reshape, tile, fill, cast, solve import Base: reshape, fill @@ -6,3 +6,4 @@ reshape(x::AbstractArray, dims::AbstractArray) = reshape(x,tuple(dims...)) tile(x::AbstractArray, mult::AbstractArray) = repeat(x,outer=tuple(mult...)) fill{T}(x::T, dims::AbstractArray) = fill(x,tuple(dims...)) cast{T}(x::AbstractArray, ::Type{T}) = convert(Array{T},x) +solve(A::AbstractArray, b::AbstractArray) = A\b diff --git a/test/backend/tensorflow.jl b/test/backend/tensorflow.jl index 7d09eccd..3d9f56b0 100644 --- a/test/backend/tensorflow.jl +++ b/test/backend/tensorflow.jl @@ -53,6 +53,9 @@ end @test tf(@net (x,y) -> Flux.tile(x,y))(A,[1,1,3]) ≈ repeat(A,outer=(1,1,3)) @test tf(@net (x,y) -> fill(x,y))(3.2,[3,2]) ≈ convert(Array{Float32},3.2*ones(3,2)) @test typeof(tf(@net x -> Flux.cast(x,Int32))(A)) == Array{Int32,3} + A = randn(Float32,(5,5)) + b = randn(Float32,(5,1)) + @test tf(@net (x,y) -> solve(x,y))(A,b) ≈ A\b end end