graph support
This commit is contained in:
parent
d42130b8cd
commit
27aa2bf8d4
|
@ -12,6 +12,7 @@ graph(::typeof(+), args...) = +(args...)
|
|||
graph(::typeof(softmax), x) = nn.softmax(x)
|
||||
graph(::typeof(relu), x) = nn.relu(x)
|
||||
graph(::typeof(tanh), x) = tanh(x)
|
||||
graph(::typeof(σ), x) = nn.sigmoid(x)
|
||||
|
||||
# reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79
|
||||
batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
|
||||
|
@ -26,6 +27,8 @@ graph(c::Conv2D, x) =
|
|||
graph(p::MaxPool, x) =
|
||||
nn.max_pool(x, [1, p.size..., 1], [1, p.stride..., 1], "VALID")
|
||||
|
||||
graph(::Flow.Group, xs...) = (xs...,)
|
||||
|
||||
type Op
|
||||
f
|
||||
shape
|
||||
|
|
Loading…
Reference in New Issue