Flux.jl/src/backend/tensorflow/graph.jl


using Base: @get!
using DataFlow: Constant, constant, Context, interpret, Split, interptuple, interplambda, interpconst
using Flux: interpmap
using TensorFlow: RawTensor

# TODO: implement Julia's type promotion rules
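# `node` lifts plain Julia values into TensorFlow graph nodes: tuples are
# mapped elementwise, tensors and variables pass through unchanged, and
# numbers become Float32 constants.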
node(x::Tuple) = map(node, x)
node(x::Tensor) = x
node(x::Variable) = x
node(x::Number) = TensorFlow.constant(Float32(x))
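
# `graph` translates a single Flux / DataFlow operation into the equivalent
# TensorFlow op.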
graph(::typeof(tuple), args...) = (args...,)
graph(s::Split, t::Tuple) = t[s.n]
graph(::typeof(softmax), x) = nn.softmax(x)
graph(::typeof(relu), x) = nn.relu(x)
graph(::typeof(σ), x) = nn.sigmoid(x)
graph(::typeof(hcat), xs...) = concat(1, xs)
graph(::typeof(seq), xs, n) = TensorFlow.unpack(xs, num = n, axis = 1)
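
# These operators are already overloaded for `Tensor`s by TensorFlow.jl, so
# the Julia function can be applied directly once its arguments are lifted
# with `node`.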
for op in (tanh, *, .*, +, -)
  @eval graph(::typeof($op), args...) = $op(node(args)...)
end

graph(::typeof(.-), args...) = -(node(args)...)

# reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79
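# `batchsize` pulls the batch dimension out of the tensor's runtime shape, so
# the batch size can remain unknown while the graph is built.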
batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
graph(::typeof(flatten), x) = reshape(x, pack([batchsize(x), Int32(-1)]))
graph(r::Reshape, x) = reshape(x, pack([batchsize(x), map(Int32, r.dims)...]))
graph(::Input, x) = x
graph(p::MaxPool, x) =
  nn.max_pool(x, [1, p.size..., 1], [1, p.stride..., 1], "VALID")

graph(op::Op, xs...) = op.f(xs...)

interp(ctx, c::Conv2D, x) =
  nn.conv2d(interpret(ctx, x), interp(ctx, Constant(c.filter)), [1, c.stride..., 1], "VALID")
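
# Each model parameter is converted to a TensorFlow `Variable` exactly once;
# `ctx[:params]` memoises the mapping, so a parameter shared between layers
# maps to a single variable node.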
interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) =
  haskey(ctx[:params], p.value) ?
    ctx[:params][p.value] :
    (ctx[:params][p.value] = Variable(p.value.x))

interp(ctx, p::Constant) = p.value
function interp(ctx, model, args...)
  g = Flux.graph(model)
  # Leaf models with no graph of their own are translated directly via `graph`.
  g == nothing && return graph(model, interpret(ctx, args)...)
  DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
  interpret(ctx, g, interpret(ctx, args)...)
end
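
# `tograph` interprets `model` applied to `args`, returning the parameter map
# (Flux parameter => TensorFlow `Variable`) together with the output tensor.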
function tograph(model, args...)
  ctx = Context(interplambda(interptuple(interpmap(interp))), params = ObjectIdDict())
  out = interp(ctx, model, map(constant, args)...)
  return ctx[:params], out
end
TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[2]
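
# Usage sketch, assuming a hypothetical model `m` taking a single 1×10 input:
#   params, out = tograph(m, rand(Float32, 1, 10))
#   t = TensorFlow.Tensor(m, rand(Float32, 1, 10))  # output tensor only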
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
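
# Build a fresh TensorFlow session for `model` with `n` placeholder inputs,
# returning the session, the parameter map, the placeholders and the output.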
function makesession(model, n)
  sess = Session(Graph())
  inputs = [placeholder(Float32) for _ = 1:n]
  params, output = tograph(model, inputs...)
  run(sess, initialize_all_variables())
  sess, params, inputs, output
end
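
# Copy the trained variable values back out of the session into the Flux
# model's parameters.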
function storeparams!(sess, params)
  for (p, t) in params
    p.x = run(sess, t)
  end
end
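
# Usage sketch, assuming a hypothetical one-input `model` and input batch `x`:
#   sess, params, inputs, output = makesession(model, 1)
#   y = run(sess, output, Dict(inputs[1] => RawTensor(x)))
#   storeparams!(sess, params)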