more TF support

This commit is contained in:
Mike J Innes 2016-10-04 22:23:37 +01:00
parent cc1ca4c3c2
commit 9e9c57d49b
1 changed files with 10 additions and 0 deletions

View File

@ -21,6 +21,7 @@ graph{T<:AArray}(p::Flux.Param{T}) = Variable(p.x)
function graph(model::Model, args...)
g = Flux.graph(model)
g nothing || error("No graph for $model")
g = Flow.mapconst(g) do x
!isa(x, Flux.ModelInput) ? x :
isa(x.name, Integer) ? args[x.name] : getfield(model, x.name)
@ -34,9 +35,18 @@ graph(::typeof(*), args...) = *(args...)
graph(::typeof(+), args...) = +(args...)
graph(::typeof(softmax), x) = nn.softmax(x)
graph(::typeof(relu), x) = nn.relu(x)
graph(::typeof(tanh), x) = tanh(x)
graph(::Input, x) = x
graph(c::Conv2D, x) =
nn.conv2d(x, graph(c.filter), [1,c.stride...,1], "VALID")
graph(p::MaxPool, x) =
nn.max_pool(x, [1, p.size..., 1], [1, p.stride..., 1], "VALID")
TensorFlow.Tensor(m::Flux.Model, args...) = graph(m, args...)
# Treat the first dimension as the batch index
# TODO: custom data type for this
batch(x) = reshape(x, (1,size(x)...))