basic tf backend

This commit is contained in:
Mike J Innes 2016-09-27 02:16:49 +01:00
parent df38a89d9a
commit b662df6ce1
3 changed files with 66 additions and 5 deletions

View File

@ -1,2 +1,2 @@
julia 0.5-
MXNet
TensorFlow

View File

@ -1,7 +1,9 @@
# TODO: load backends lazily
include("mxnet/mxnet.jl")
# include("mxnet/mxnet.jl")
# using .MX
# export mxnet
using .MX
export mxnet
include("tensorflow/tensorflow.jl")
using .TF
export tf

View File

@ -0,0 +1,59 @@
module TF
using ..Flux, Flow, TensorFlow
# Workaround for tensor display bug
using Juno
Media.render(::Juno.Clipboard, ::Tensor) = "Tensor()"
cvalue(x) = x
cvalue(c::Constant) = c.value
cvalue(v::Vertex) = cvalue(value(v))
graph(x::Tensor) = x
matrixify(xs) = xs
matrixify(xs::Vector) = xs[:,1:1]
# TODO: detect variable reuse
graph{T<:AArray}(p::Flux.Param{T}) = Variable(matrixify(p.x))
function graph(model::Model, args...)
g = Flux.graph(model)
g = Flow.mapconst(g) do x
!isa(x, Flux.ModelInput) ? x :
isa(x.name, Integer) ? args[x.name] : getfield(model, x.name)
end
postwalk(g) do v
vertex(graph(cvalue(v), cvalue.(inputs(v))...))
end |> value
end
graph(::typeof(*), args...) = *(args...)
graph(::typeof(+), args...) = +(args...)
type Model
session::Session
inputs::Vector{Tensor}
graph::Tensor
end
Media.render(::Juno.Clipboard, ::Model) = "Flux.TF.Model()"
function tf(model)
sess = Session()
input = placeholder(Float64)
g = graph(model, input)
run(sess, initialize_all_variables())
Model(sess, [input], g)
end
function (m::Model)(args...)
@assert length(args) == length(m.inputs)
run(m.session, m.graph, Dict(zip(m.inputs, args)))
end
# m = Flux.Dense(784, 10)
# t = tf(m)
# t(randn(784,1))
end