graph support
This commit is contained in:
parent
d42130b8cd
commit
27aa2bf8d4
@ -12,6 +12,7 @@ graph(::typeof(+), args...) = +(args...)
|
|||||||
graph(::typeof(softmax), x) = nn.softmax(x)
|
graph(::typeof(softmax), x) = nn.softmax(x)
|
||||||
graph(::typeof(relu), x) = nn.relu(x)
|
graph(::typeof(relu), x) = nn.relu(x)
|
||||||
graph(::typeof(tanh), x) = tanh(x)
|
graph(::typeof(tanh), x) = tanh(x)
|
||||||
|
graph(::typeof(σ), x) = nn.sigmoid(x)
|
||||||
|
|
||||||
# reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79
|
# reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79
|
||||||
batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
|
batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
|
||||||
@ -26,6 +27,8 @@ graph(c::Conv2D, x) =
|
|||||||
graph(p::MaxPool, x) =
|
graph(p::MaxPool, x) =
|
||||||
nn.max_pool(x, [1, p.size..., 1], [1, p.stride..., 1], "VALID")
|
nn.max_pool(x, [1, p.size..., 1], [1, p.stride..., 1], "VALID")
|
||||||
|
|
||||||
|
graph(::Flow.Group, xs...) = (xs...,)
|
||||||
|
|
||||||
type Op
|
type Op
|
||||||
f
|
f
|
||||||
shape
|
shape
|
||||||
|
Loading…
Reference in New Issue
Block a user