diff --git a/src/backend/tensorflow/graph.jl b/src/backend/tensorflow/graph.jl index 8f5b47bc..3c81e6e7 100644 --- a/src/backend/tensorflow/graph.jl +++ b/src/backend/tensorflow/graph.jl @@ -14,6 +14,7 @@ node(x::Number) = TensorFlow.constant(Float32(x)) graph(::typeof(tuple), args...) = (args...,) graph(s::Split, t::Tuple) = t[s.n] +graph(::typeof(identity), x) = TensorFlow.identity(x) graph(::typeof(softmax), x) = nn.softmax(x) graph(::typeof(relu), x) = nn.relu(x) graph(::typeof(σ), x) = nn.sigmoid(x)