diff --git a/src/backend/tensorflow/graph.jl b/src/backend/tensorflow/graph.jl index bacf8652..7707cf8f 100644 --- a/src/backend/tensorflow/graph.jl +++ b/src/backend/tensorflow/graph.jl @@ -38,6 +38,7 @@ graph(::typeof(Flux.cast), args...) = TensorFlow.cast(args...) graph(::typeof(solve), A, b) = TensorFlow.matrix_solve(A, b) graph(::typeof(triangular_solve), A, b) = TensorFlow.matrix_triangular_solve(A, b; lower=false) graph(::typeof(randu), x) = Ops.random_uniform(convert(Tensor{Int32},x);dtype=Float32) +graph(::typeof(randn), x) = TensorFlow.random_normal(convert(Tensor{Int32},x);dtype=Float32) for op in (*, .*, .+, .^, log, exp, ceil, floor, sqrt, abs, cos, sin, tan, atan, asin, acos, tanh, lgamma, erf, erfc, real, imag, conj, diff --git a/src/ops.jl b/src/ops.jl index cac4805c..31292501 100644 --- a/src/ops.jl +++ b/src/ops.jl @@ -1,6 +1,6 @@ -export reshape, tile, fill, cast, solve, triangular_solve, randu +export reshape, tile, fill, cast, solve, triangular_solve, randu, randn -import Base: reshape, fill +import Base: reshape, fill, randn reshape(x::AbstractArray, dims::AbstractArray) = reshape(x,tuple(dims...)) tile(x::AbstractArray, mult::AbstractArray) = repeat(x,outer=tuple(mult...)) @@ -9,3 +9,4 @@ cast{T}(x::AbstractArray, ::Type{T}) = convert(Array{T},x) solve(A::AbstractArray, b::AbstractArray) = A\b triangular_solve(A::AbstractArray, b::AbstractArray) = A\b randu(x::AbstractArray) = rand(tuple(x...)) +randn(x::AbstractArray) = randn(tuple(x...)) diff --git a/test/backend/tensorflow.jl b/test/backend/tensorflow.jl index 8ca68c2e..77fdab19 100644 --- a/test/backend/tensorflow.jl +++ b/test/backend/tensorflow.jl @@ -59,6 +59,7 @@ end _,A,_ = lu(A) @test tf(@net (x,y) -> triangular_solve(x,y))(A,b) ≈ A\b @test size(tf(@net x -> randu(x))([2,3])) == (2,3) + @test size(tf(@net x -> randn(x))([2,3])) == (2,3) end end