Flux.jl/src/backend/mxnet/graph.jl

104 lines
2.9 KiB
Julia
Raw Normal View History

2017-01-29 10:39:30 +00:00
"""
    nodename(s::mx.SymbolicNode)

Return the name of the MXNet symbol `s` as a `Symbol`, obtained through the
`MXSymbolGetName` C API call.

Raises an error when the call reports no name for `s` (cleared success flag
or a NULL name pointer).
"""
function nodename(s::mx.SymbolicNode)
  name = Ref{mx.char_p}(0)
  # The C API declares the `success` out-parameter as `int*`, so it has to be
  # passed as a `Cint` reference; `Int` is 8 bytes wide on 64-bit platforms.
  success = Ref{Cint}(0)
  mx.@mxcall(:MXSymbolGetName, (mx.MX_handle, Ref{mx.char_p}, Ref{Cint}),
             s.handle.value, name, success)
  # Fail with a clear message instead of wrapping a NULL pointer below.
  (success[] <= 0 || name[] == C_NULL) && error("MXNet symbol has no name")
  return Symbol(unsafe_wrap(String, name[]))
end
2017-01-28 17:02:49 +00:00
using Base: @get!
2017-02-01 06:57:02 +00:00
using DataFlow: Constant, constant
using DataFlow.Interpreter
2017-02-01 14:21:08 +00:00
using DataFlow.Interpreter: Exception, totrace
2017-01-28 17:02:49 +00:00
using Flux: imap
# TODO: implement Julia's type promotion rules
# `node` converts a constant value into its graph representation.
# A symbolic node is already in that form and is returned unchanged.
function node(x::mx.SymbolicNode)
  return x
end

# Tuples are converted element by element, yielding a tuple of the results.
function node(x::Tuple)
  return map(node, x)
end
# Lowering rules: each method maps one primitive function (selected by its
# type) applied to the given arguments onto the corresponding `mx` call.

# `tuple` simply repackages its arguments as a tuple.
function graph(::typeof(tuple), args...)
  return (args...,)
end

# Addition is forwarded to `mx.broadcast_plus`.
function graph(::typeof(+), args...)
  return mx.broadcast_plus(args...)
end

# Activation functions all lower to `mx.Activation` with a different `act_type`.
function graph(::typeof(σ), x)
  return mx.Activation(data = x, act_type = :sigmoid)
end

function graph(::typeof(relu), x)
  return mx.Activation(data = x, act_type = :relu)
end

function graph(::typeof(tanh), x)
  return mx.Activation(data = x, act_type = :tanh)
end

# `flatten` lowers to `mx.Flatten`.
function graph(::typeof(flatten), x)
  return mx.Flatten(data = x)
end

# Softmax is built by hand: exp(xs) divided by the sum of exp(xs), with the
# sum reshaped to (1,1) before the broadcasting division.
# NOTE(review): `mx.sum` is called without an axis argument; presumably this
# normalises over every element of `xs` rather than per sample — confirm
# against the intended input shape.
function graph(::typeof(softmax), xs)
  numerator = exp(xs)
  denominator = mx.Reshape(mx.sum(exp(xs)), shape = (1,1))
  return mx.broadcast_div(numerator, denominator)
end

# Concatenation along `dim` lowers to `mx.Concat`.
function graph(::typeof(cat), dim::Integer, a...)
  return mx.Concat(a..., dim = dim)
end

# `vcat` is concatenation along dimension 1.
function graph(::typeof(vcat), a...)
  return graph(cat, 1, a...)
end

# An `Input` passes its argument through untouched.
function graph(::Input, x)
  return x
end
# graph(vars, c::Conv, x) =
# mx.Convolution(data = x,
# kernel = c.size,
# num_filter = c.features,
# stride = c.stride)
#
# graph(vars, p::MaxPool, x) =
# mx.Pooling(data = x,
# pool_type = :max,
# kernel = p.size,
# stride = p.stride)
#
# graph(vars, d::Dense, x) =
# mx.FullyConnected(data = x,
# num_hidden = size(d.W.x, 1),
# weight = graph(vars, d.W),
# bias = graph(vars, d.b))
2017-01-30 17:42:01 +00:00
# `register` records where a freshly created node came from.
# For a symbolic node, the current `stack(ctx)` is stored in `ctx[:stacks]`
# under the node's name; the node itself is returned unchanged.
function register(ctx::Context, node::mx.SymbolicNode)
  stacks = ctx[:stacks]
  stacks[nodename(node)] = stack(ctx)
  return node
end

# Anything that is not a symbolic node is returned as-is, with nothing recorded.
function register(ctx::Context, node)
  return node
end
2017-01-30 17:49:18 +00:00
# Constants in the graph.
# A constant wrapping an array-valued `Flux.Param` becomes a fresh
# `mx.Variable`; the parameter's `x` field is stored in `ctx[:params]` under
# the same generated name so it can be looked up by that name later.
function graph{T<:AArray}(ctx::Context, p::Constant{Flux.Param{T}})
  name = gensym()
  params = ctx[:params]
  params[name] = p.value.x
  return mx.Variable(name)
end

# Any other constant is converted through `node`.
function graph(ctx::Context, p::Constant)
  return node(p.value)
end
"""
    graph(ctx::Context, model, args...)

Lower `model` applied to `args` within the interpreter context `ctx`.

When `Flux.graph(model)` returns `nothing`, `graph(model, args...)` is called
under `@ithrow ctx` and its result is passed to `register`. Otherwise the graph
returned by `Flux.graph(model)` is run with `interpret(ctx, g, args...)`.

Raises an error if that graph is cyclic.
"""
function graph(ctx::Context, model, args...)
  g = Flux.graph(model)
  # Compare by identity: `==` can be overloaded, `===` cannot.
  if g === nothing
    return register(ctx, @ithrow ctx graph(model, args...))
  end
  DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
  return interpret(ctx, g, args...)
end
2017-02-20 21:49:47 +00:00
# Catch-all `graph` method on a `Context`: re-invokes `graph(ctx, args...)`
# inside `@ithrow ctx`.
# NOTE(review): the methods taking at least one argument after `ctx` look
# strictly more specific than this signature, which would leave this method
# reachable only as `graph(ctx)` — where it would call itself. Confirm whether
# a separately named wrapper was intended.
function graph(ctx::Context, args...)
  return @ithrow ctx graph(ctx, args...)
end
2017-01-28 17:02:49 +00:00
"""
    tograph(model, args...)

Lower `model` applied to `args` using a fresh interpreter context.

Returns a 3-tuple: the `params` dictionary and the `stacks` dictionary held by
the context after lowering, followed by the value produced by
`graph(ctx, model, args...)` (evaluated under `@icatch`).
"""
function tograph(model, args...)
  interp = mux(iline, ilambda, imap, iargs, ituple, graph)
  ctx = Context(interp, params = Dict(), stacks = Dict())
  out = @icatch graph(ctx, model, args...)
  return ctx[:params], ctx[:stacks], out
end
2017-01-29 11:59:37 +00:00
# Error Handling
2017-01-29 18:05:03 +00:00
using Juno
# Juno integration: the message shown for an `mx.MXError` is its `msg` field.
function Juno.errmsg(e::mx.MXError)
  return e.msg
end
2017-01-29 11:59:37 +00:00
"""
    errnode(e::mx.MXError)

Extract the name that follows `"Error in "` (and precedes a colon) in `e.msg`
and return it as a `Symbol`. Returns `nothing` when the message does not match
that pattern.
"""
function errnode(e::mx.MXError)
  m = match(r"Error in (\w+):", e.msg)
  # `match` yields `nothing` on failure; test by identity rather than `==`.
  m === nothing && return nothing
  return Symbol(m.captures[1])
end
"""
    @mxerr stk ex

Evaluate `ex`. If it throws an `mx.MXError` whose message names a node (as
determined by `errnode`), throw an `Exception` built from the original error
and `totrace(stk[node])` instead. Every other error is rethrown unchanged.

`stk` is only evaluated when such an error has been caught.
"""
macro mxerr(stk, ex)
  quote
    try
      $(esc(ex))
    catch e
      # Only MXNet errors are translated; everything else propagates as-is.
      isa(e, mx.MXError) || rethrow()
      node = errnode(e)
      # Identity check against `nothing`: no node name could be extracted.
      node === nothing && rethrow()
      stk = $(esc(stk))
      throw(Exception(e, totrace(stk[node])))
    end
  end
end