using Base: @get!
using DataFlow: Constant, constant, Split
using DataFlow.Interpreter
using DataFlow.Interpreter: stack
using Flux: imap
using TensorFlow: RawTensor, TFException

# TODO: implement Julia's type promotion rules

# `node` lifts Julia values into TensorFlow graph nodes: tuples are mapped
# element-wise, tensors and variables pass through unchanged, and numbers
# become `Float32` constants.
node(x::Tuple) = map(node, x)
node(x::Tensor) = x
node(x::Variable) = x
node(x::Number) = TensorFlow.constant(Float32(x))
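# e.g. `node(0.5)` (illustrative) lifts a plain Julia scalar to a `Float32`
# TensorFlow constant, so numeric literals can be mixed into a graph.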

# `graph` translates a Flux/DataFlow operation into the corresponding
# TensorFlow.jl call, building up the graph one node at a time.
graph(::typeof(tuple), args...) = (args...,)
graph(s::Split, t::Tuple) = t[s.n]
graph(::typeof(getindex), t::Tuple, n::Integer) = t[n]
graph(::typeof(identity), x) = TensorFlow.identity(x)
graph(::typeof(softmax), x) = nn.softmax(x)
graph(::typeof(relu), x) = nn.relu(x)
graph(::typeof(σ), x) = nn.sigmoid(x)
graph(::typeof(hcat), xs...) = concat(1, xs)
graph(::typeof(seq), xs, n) = TensorFlow.unpack(xs, num = n, axis = 1)

# These operations have the same name and meaning in TensorFlow.jl, so they
# can be forwarded directly.
for op in (tanh, *, .*, .+)
  @eval graph(::typeof($op), args...) = $op(args...)
end
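# The loop above `@eval`s one `graph` method per operation; for `tanh`, for
# example, it defines `graph(::typeof(tanh), args...) = tanh(args...)`.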

graph(::typeof(.-), args...) = -(args...)

# reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79
batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
graph(::typeof(flatten), x) = reshape(x, pack([batchsize(x), Int32(-1)]))
graph(r::Reshape, x) = reshape(x, pack([batchsize(x), map(Int32, r.dims)...]))
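# Both methods above keep the (dynamic) batch dimension intact: `flatten`
# collapses each sample to a vector, while `Reshape` reshapes each sample to
# `r.dims` (so, illustratively, a 28×28 image sample flattens to length 784).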

graph(::Input, x) = x

graph(p::MaxPool, x) =
  nn.max_pool(x, [1, p.size..., 1], [1, p.stride..., 1], "VALID")

graph(op::Op, xs...) = op.f(xs...)

# The `Context`-aware method also records the Julia call stack for each
# TensorFlow node it creates, so runtime errors can be traced back to model
# code (see `@tferr` below).
function graph(ctx::Context, model, args...)
  node = graph(model, interpv(ctx, args)...)
  node isa Tensor && (ctx[:stacks][node.op.name] = stack(ctx))
  return node
end

interp(ctx, c::Conv2D, x) =
  nn.conv2d(interpv(ctx, x), interp(ctx, Constant(c.filter)), [1,c.stride...,1], "VALID")

# Flux parameters are converted to TensorFlow `Variable`s once and cached in
# `ctx[:params]`, so repeated references map to the same variable.
interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) =
  haskey(ctx[:params], p.value) ?
    ctx[:params][p.value] :
    (ctx[:params][p.value] = Variable(convertel(Float32, p.value.x)))

interp(ctx, p::Constant) = p.value

# Models without a DataFlow graph of their own fall back to `graph` above;
# otherwise their graph is interpreted recursively (after checking that it
# is acyclic).
function interp(ctx, model, args...)
  g = Flux.graph(model)
  g == nothing && return graph(ctx, model, args...)
  DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
  interpret(ctx, g, interpv(ctx, args)...)
end

# Build the full TensorFlow graph for `model`, returning the parameter map,
# the recorded stacks and the output node.
function tograph(model, args...)
  ctx = Context(mux(iline, ilambda, imap, interp),
                params = ObjectIdDict(), stacks = Dict())
  out = interp(ctx, model, map(constant, args)...)
  return ctx[:params], ctx[:stacks], out
end

TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[3]
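# Usage sketch (illustrative only; `m` and `x` stand for a Flux model and a
# sample input):
#
#   params, stacks, out = tograph(m, x)  # parameter map, stack info, output
#   y = TensorFlow.Tensor(m, x)          # shorthand for the output tensor only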

RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))

# Error Handling

using Juno
using MacroTools: @q
using DataFlow.Interpreter: Exception, totrace

Juno.errmsg(e::TFException) = string(e.status)

# Extract the name of the TensorFlow node that caused the error, if the
# status message identifies one.
function errnode(e::TFException)
  m = match(r"Node: ([\w\d]+) =", string(e.status))
  m == nothing && return
  m.captures[1]
end
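# For example, a status message containing "Node: conv2d_1 = Conv2D[...]"
# (illustrative) yields "conv2d_1"; when no node is named, `nothing` is
# returned.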

errnode(e) = nothing

# Wrap an expression so that TensorFlow runtime errors are rethrown with the
# Julia stack recorded for the offending graph node; errors with no recorded
# stack are simply rethrown.
macro tferr(stk, ex)
  @q try
    $(esc(ex))
  catch e
    (node = errnode(e)) != nothing || rethrow()
    stk = $(esc(stk))
    haskey(stk, node) || rethrow()
    throw(Exception(e, totrace(stk[node])))
  end
end
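
# Usage sketch (illustrative; `stacks` and `out` stand for values returned by
# `tograph`, `sess` for a TensorFlow session and `feeds` for a feed dictionary):
#
#   @tferr stacks run(sess, out, feeds)
#
# A failure inside a node built here is rethrown with the Julia stack trace
# recorded when that node was created.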