diff --git a/src/backend/tensorflow/graph.jl b/src/backend/tensorflow/graph.jl index 291da771..bacf8652 100644 --- a/src/backend/tensorflow/graph.jl +++ b/src/backend/tensorflow/graph.jl @@ -37,6 +37,7 @@ graph(::typeof(fill), x, dims) = Ops.fill(convert(Tensor{Int32}, dims), Tensor(x graph(::typeof(Flux.cast), args...) = TensorFlow.cast(args...) graph(::typeof(solve), A, b) = TensorFlow.matrix_solve(A, b) graph(::typeof(triangular_solve), A, b) = TensorFlow.matrix_triangular_solve(A, b; lower=false) +graph(::typeof(randu), x) = Ops.random_uniform(convert(Tensor{Int32},x);dtype=Float32) for op in (*, .*, .+, .^, log, exp, ceil, floor, sqrt, abs, cos, sin, tan, atan, asin, acos, tanh, lgamma, erf, erfc, real, imag, conj, diff --git a/src/ops.jl b/src/ops.jl index 4d57e565..cac4805c 100644 --- a/src/ops.jl +++ b/src/ops.jl @@ -1,4 +1,4 @@ -export reshape, tile, fill, cast, solve, triangular_solve +export reshape, tile, fill, cast, solve, triangular_solve, randu import Base: reshape, fill @@ -8,3 +8,4 @@ fill{T}(x::T, dims::AbstractArray) = fill(x,tuple(dims...)) cast{T}(x::AbstractArray, ::Type{T}) = convert(Array{T},x) solve(A::AbstractArray, b::AbstractArray) = A\b triangular_solve(A::AbstractArray, b::AbstractArray) = A\b +randu(x::AbstractArray) = rand(tuple(x...)) diff --git a/test/backend/tensorflow.jl b/test/backend/tensorflow.jl index 4944cd7f..8ca68c2e 100644 --- a/test/backend/tensorflow.jl +++ b/test/backend/tensorflow.jl @@ -58,6 +58,7 @@ end @test tf(@net (x,y) -> solve(x,y))(A,b) ≈ A\b _,A,_ = lu(A) @test tf(@net (x,y) -> triangular_solve(x,y))(A,b) ≈ A\b + @test size(tf(@net x -> randu(x))([2,3])) == (2,3) end end