graph support

This commit is contained in:
Mike J Innes 2016-10-28 15:13:43 +01:00
parent d42130b8cd
commit 27aa2bf8d4
1 changed files with 3 additions and 0 deletions

View File

@ -12,6 +12,7 @@ graph(::typeof(+), args...) = +(args...)
graph(::typeof(softmax), x) = nn.softmax(x)
graph(::typeof(relu), x) = nn.relu(x)
graph(::typeof(tanh), x) = tanh(x)
graph(::typeof(σ), x) = nn.sigmoid(x)
# reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79
batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
@ -26,6 +27,8 @@ graph(c::Conv2D, x) =
graph(p::MaxPool, x) =
nn.max_pool(x, [1, p.size..., 1], [1, p.stride..., 1], "VALID")
graph(::Flow.Group, xs...) = (xs...,)
type Op
f
shape