2017-01-29 10:39:30 +00:00
|
|
|
|
# Return the name MXNet has assigned to the symbolic node `s`, as a Symbol.
#
# MXSymbolGetName reports through its `success` out-parameter whether a name
# was produced; when it was not, `name` is left unusable. The previous check
# (`success[] != -1`) could never fail, so that case surfaced only as an
# opaque NULL-pointer error. We now raise a descriptive error instead.
# NOTE(review): a zero `success` is expected for symbols without a single
# output (e.g. grouped symbols) — confirm against the MXNet C API docs.
function nodename(s::mx.SymbolicNode)
  name = Ref{mx.char_p}(0)
  success = Ref(0)
  mx.@mxcall(:MXSymbolGetName, (mx.MX_handle, Ref{mx.char_p}, Ref{Int}),
             s.handle.value, name, success)
  success[] > 0 || error("MXSymbolGetName did not return a name for this node")
  # `unsafe_string` copies the C string; the buffer is owned by MXNet and
  # must not be aliased by a Julia String (as `unsafe_wrap` would do).
  return Symbol(unsafe_string(name[]))
end
|
|
|
|
|
|
2017-01-28 17:02:49 +00:00
|
|
|
|
using Base: @get!
|
|
|
|
|
using DataFlow: Constant, constant, Context, interpret, Split,
|
2017-01-30 17:21:49 +00:00
|
|
|
|
interpv, ituple, ilambda, iconst, iline, iargs, stack, mux
|
2017-01-28 17:02:49 +00:00
|
|
|
|
using Flux: imap
|
|
|
|
|
|
|
|
|
|
# TODO: implement Julia's type promotion rules
|
|
|
|
|
|
|
|
|
|
# Coerce an interpreted value into MXNet symbolic form.
# A tuple is converted element by element and stays a tuple.
node(x::Tuple) = ntuple(i -> node(x[i]), length(x))
# A symbolic node needs no conversion.
node(x::mx.SymbolicNode) = x
|
|
|
|
|
# node(x::Number) = TensorFlow.constant(Float32(x))
|
|
|
|
|
|
|
|
|
|
# Building a tuple just collects the (already lowered) arguments; the
# varargs binding is itself a tuple.
graph(::typeof(tuple), xs...) = xs
# `Split` selects its `n`th component from a tuple of lowered values.
graph(sp::Split, tup::Tuple) = getindex(tup, sp.n)
|
2017-01-28 17:37:02 +00:00
|
|
|
|
# Arithmetic: `*` lowers to MXNet's dot product, `+` to broadcasting addition.
graph(::typeof(*), xs...) = mx.dot(xs...)
graph(::typeof(+), xs...) = mx.broadcast_plus(xs...)

# Elementwise activations all lower to `mx.Activation`; only `act_type` varies.
for (f, act) in ((σ, :sigmoid), (relu, :relu), (tanh, :tanh))
  @eval graph(::typeof($f), x) = mx.Activation(data = x, act_type = $(QuoteNode(act)))
end

# Flattening maps directly onto MXNet's Flatten operator.
graph(::typeof(flatten), x) = mx.Flatten(data = x)
|
|
|
|
|
|
|
|
|
|
# Lower `softmax(xs)` to exp(xs) / sum(exp(xs)).
#
# The `exp` node is built once and shared by numerator and denominator;
# the previous version constructed two identical `exp` subgraphs.
# NOTE(review): `mx.sum` is called without an axis, so the normaliser
# appears to be the sum over the whole tensor (reshaped to 1×1 for
# broadcasting) rather than a per-sample sum — confirm for batched input.
function graph(::typeof(softmax), xs)
  ex = exp(xs)
  return mx.broadcast_div(ex, mx.Reshape(mx.sum(ex), shape = (1,1)))
end
|
|
|
|
|
|
|
|
|
|
# Concatenate lowered nodes along `dim` using MXNet's Concat operator.
# NOTE(review): `dim` is forwarded unchanged; confirm MXNet's axis numbering
# matches Julia's for inputs of rank > 2.
graph(::typeof(cat), dim::Integer, a...) = mx.Concat(a..., dim = dim)

# `vcat` is concatenation along dimension 1. It must dispatch through
# `graph`: `node` has only one-argument methods, so `node(cat, 1, a...)`
# raised a MethodError.
graph(::typeof(vcat), a...) = graph(cat, 1, a...)

# An `Input` layer passes its argument through unchanged.
graph(::Input, x) = x
|
|
|
|
|
|
|
|
|
|
# graph(vars, c::Conv, x) =
|
|
|
|
|
# mx.Convolution(data = x,
|
|
|
|
|
# kernel = c.size,
|
|
|
|
|
# num_filter = c.features,
|
|
|
|
|
# stride = c.stride)
|
|
|
|
|
#
|
|
|
|
|
# graph(vars, p::MaxPool, x) =
|
|
|
|
|
# mx.Pooling(data = x,
|
|
|
|
|
# pool_type = :max,
|
|
|
|
|
# kernel = p.size,
|
|
|
|
|
# stride = p.stride)
|
|
|
|
|
#
|
|
|
|
|
# graph(vars, d::Dense, x) =
|
|
|
|
|
# mx.FullyConnected(data = x,
|
|
|
|
|
# num_hidden = size(d.W.x, 1),
|
|
|
|
|
# weight = graph(vars, d.W),
|
|
|
|
|
# bias = graph(vars, d.b))
|
|
|
|
|
|
|
|
|
|
# A constant array-valued `Flux.Param` becomes a fresh MXNet variable.
# Its current array (`p.value.x`) is stored in `ctx[:params]` under the same
# gensym'd name (presumably so it can be bound to the variable later — verify
# against callers of `tograph`).
function interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}})
  varname = gensym()
  store = ctx[:params]
  store[varname] = p.value.x
  return mx.Variable(varname)
end

# Any other constant is converted with `node`.
interp(ctx, c::Constant) = node(c.value)
|
|
|
|
|
|
|
|
|
|
# Lower `model(args...)` with the context-free `graph` methods. When the
# result is an MXNet symbol, record the current interpreter stack under the
# symbol's name in `ctx[:stacks]`, so that errors mentioning that name can be
# mapped back to the stack (see `@mxerr`).
function graph(ctx::Context, model, args...)
  out = graph(model, args...)
  if isa(out, mx.SymbolicNode)
    ctx[:stacks][nodename(out)] = stack(ctx)
  end
  return out
end
|
|
|
|
|
|
|
|
|
|
# Interpret `model` applied to `args` within `ctx`.
#
# If `Flux.graph(model)` yields `nothing`, the model has no DataFlow graph and
# is lowered directly via `graph(ctx, model, args...)`. Otherwise its graph is
# interpreted node by node; cyclic graphs are rejected with an error.
function interp(ctx, model, args...)
  g = Flux.graph(model)
  # `===` is the idiomatic `nothing` test and never dispatches `==` on graph types.
  g === nothing && return graph(ctx, model, args...)
  DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
  return interpret(ctx, g, args...)
end
|
|
|
|
|
|
|
|
|
|
# Lower `model` applied to `args` into an MXNet symbolic graph.
#
# Each argument is wrapped with `constant` before interpretation.
# Returns a tuple `(params, stacks, out)`:
#   params — Dict of parameter name => array, filled while interpreting Params
#   stacks — Dict of MXNet node name => interpreter stack at creation time
#   out    — the lowered output
function tograph(model, args...)
  interpreter = mux(iline, ilambda, imap, iargs, interp)
  ctx = Context(interpreter, params = Dict(), stacks = Dict())
  out = interp(ctx, model, map(constant, args)...)
  return ctx[:params], ctx[:stacks], out
end
|
2017-01-29 11:59:37 +00:00
|
|
|
|
|
|
|
|
|
# Error Handling
|
|
|
|
|
|
2017-01-29 18:05:03 +00:00
|
|
|
|
using Juno
|
|
|
|
|
# Give Juno the message text of an MXNet error (presumably what Juno shows
# when rendering the error — confirm against Juno's `errmsg` docs).
# NOTE(review): this extends Juno's function for MXNet's type; neither is
# owned by this module, so this is technically type piracy.
Juno.errmsg(e::mx.MXError) = e.msg
|
|
|
|
|
|
2017-01-29 11:59:37 +00:00
|
|
|
|
# Extract the name of the node an MXNet error refers to.
#
# Looks for an "Error in <name>:" fragment in the error message, where
# <name> consists of word characters. Returns that name as a Symbol, or
# `nothing` when the message contains no such fragment.
function errnode(e::mx.MXError)
  m = match(r"Error in (\w+):", e.msg)
  m === nothing && return nothing
  return Symbol(m.captures[1])
end
|
|
|
|
|
|
|
|
|
|
# `@mxerr stacks expr` evaluates `expr` and translates MXNet errors.
#
# `stacks` must evaluate to a mapping of node name => interpreter stack, such
# as the one returned by `tograph`. If `expr` throws an `mx.MXError` whose
# message names a node found in `stacks`, the error is rethrown as a
# `DataFlow.Exception` carrying the trace built from that node's stack.
#
# Every other error is rethrown unchanged. This includes an MXError that names
# a node never recorded in `stacks`, such as an intermediate node created
# inside a single lowering. Previously that case raised a KeyError, which hid
# the original MXNet error.
macro mxerr(stk, ex)
  :(try
      $(esc(ex))
    catch e
      (isa(e, mx.MXError) && (node = errnode(e)) !== nothing) || rethrow()
      stk = $(esc(stk))
      haskey(stk, node) || rethrow()
      throw(DataFlow.Exception(e, DataFlow.totrace(stk[node])))
    end)
end
|