diff --git a/src/backend/tensorflow/graph.jl b/src/backend/tensorflow/graph.jl index 4235d047..562252d9 100644 --- a/src/backend/tensorflow/graph.jl +++ b/src/backend/tensorflow/graph.jl @@ -33,6 +33,7 @@ graph(::typeof(size), x) = TensorFlow.size(x) graph(::typeof(chol), args...) = TensorFlow.transpose(TensorFlow.cholesky(args...)) graph(::typeof(reshape), x, dims) = TensorFlow.reshape(x,convert(Tensor{Int32},dims)) graph(::typeof(Flux.tile), args...) = TensorFlow.tile(args...) +graph(::typeof(fill), x, dims) = Ops.fill(convert(Tensor{Int32}, dims), Tensor(x)) for op in (*, .*, .+, .^, log, exp, ceil, floor, sqrt, abs, cos, sin, tan, atan, asin, acos, tanh, lgamma, erf, erfc, real, imag, conj, diff --git a/src/ops.jl b/src/ops.jl index 7c25ba56..b76ef203 100644 --- a/src/ops.jl +++ b/src/ops.jl @@ -1,6 +1,7 @@ -export reshape, tile +export reshape, tile, fill -import Base: reshape +import Base: reshape, fill reshape(x::AbstractArray, dims::AbstractArray) = reshape(x,tuple(dims...)) tile(x::AbstractArray, mult::AbstractArray) = repeat(x,outer=tuple(mult...)) +fill{T}(x::T, dims::AbstractArray) = fill(x,tuple(dims...)) diff --git a/test/backend/tensorflow.jl b/test/backend/tensorflow.jl index 0fdd5421..531bf38b 100644 --- a/test/backend/tensorflow.jl +++ b/test/backend/tensorflow.jl @@ -51,6 +51,7 @@ end @test transpose(tf(@net (x,y) -> reshape(x,y))(transpose(A),[2,9])) ≈ reshape(A,(9,2)) # Note: TF is row major and julia is not A = randn(Float32,(4,3,1)) @test tf(@net (x,y) -> Flux.tile(x,y))(A,[1,1,3]) ≈ repeat(A,outer=(1,1,3)) + @test tf(@net (x,y) -> fill(x,y))(3.2,[3,2]) ≈ convert(Array{Float32},3.2*ones(3,2)) end end