more TF support
This commit is contained in:
parent
cc1ca4c3c2
commit
9e9c57d49b
|
@ -21,6 +21,7 @@ graph{T<:AArray}(p::Flux.Param{T}) = Variable(p.x)
|
|||
|
||||
function graph(model::Model, args...)
|
||||
g = Flux.graph(model)
|
||||
g ≠ nothing || error("No graph for $model")
|
||||
g = Flow.mapconst(g) do x
|
||||
!isa(x, Flux.ModelInput) ? x :
|
||||
isa(x.name, Integer) ? args[x.name] : getfield(model, x.name)
|
||||
|
@ -34,9 +35,18 @@ graph(::typeof(*), args...) = *(args...)
|
|||
graph(::typeof(+), args...) = +(args...)
|
||||
graph(::typeof(softmax), x) = nn.softmax(x)
|
||||
graph(::typeof(relu), x) = nn.relu(x)
|
||||
graph(::typeof(tanh), x) = tanh(x)
|
||||
|
||||
graph(::Input, x) = x
|
||||
|
||||
graph(c::Conv2D, x) =
|
||||
nn.conv2d(x, graph(c.filter), [1,c.stride...,1], "VALID")
|
||||
|
||||
graph(p::MaxPool, x) =
|
||||
nn.max_pool(x, [1, p.size..., 1], [1, p.stride..., 1], "VALID")
|
||||
|
||||
TensorFlow.Tensor(m::Flux.Model, args...) = graph(m, args...)
|
||||
|
||||
# Treat the first dimension as the batch index
|
||||
# TODO: custom data type for this
|
||||
batch(x) = reshape(x, (1,size(x)...))
|
||||
|
|
Loading…
Reference in New Issue