109 lines
3.2 KiB
Julia
109 lines
3.2 KiB
Julia
# Return the name MXNet has assigned to the symbolic node `s`, as a `Symbol`.
# Throws an error if MXNet reports that the node carries no name.
function nodename(s::mx.SymbolicNode)
    name = Ref{mx.char_p}(0)
    # The C parameter is `int *success`, so a 32-bit `Cint` slot is required.
    success = Ref{Cint}(0)
    mx.@mxcall(:MXSymbolGetName, (mx.MX_handle, Ref{mx.char_p}, Ref{Cint}),
               s.handle.value, name, success)
    # Per MXNet's C API, `success` is 1 when a name was returned and 0 otherwise
    # (it is never -1); `name[]` must not be dereferenced unless it is 1.
    success[] == 1 || error("MXNet symbolic node has no name")
    # Copy the C string: the buffer belongs to MXNet and may be reused later.
    return Symbol(unsafe_string(name[]))
end
|
||
|
||
using Base: @get!
|
||
using DataFlow: Constant, constant
|
||
using DataFlow.Interpreter
|
||
using DataFlow.Interpreter: Exception, totrace
|
||
using Flux: imap
|
||
|
||
# TODO: implement Julia's type promotion rules
|
||
|
||
# Convert a value into MXNet symbolic form: tuples are converted element-wise,
# and a symbolic node is already in the right form, so it passes through as-is.
node(xs::Tuple) = map(node, xs)
node(sym::mx.SymbolicNode) = sym
|
||
|
||
# Direct translations of primitive Julia functions into MXNet symbolic operations.
graph(::typeof(tuple), args...) = (args...,)

# Julia parses `a + b + c` as a single call `+(a, b, c)`, whereas
# `broadcast_plus` takes exactly two operands, so fold the arguments left to
# right. A lone argument (unary `+`) is returned unchanged.
graph(::typeof(+), args...) = foldl(mx.broadcast_plus, args)

graph(::typeof(*), x, W) = mx.dot(transpose(W), x) # Adjustments for batching

graph(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
graph(::typeof(relu), x) = mx.Activation(data = x, act_type = :relu)
graph(::typeof(tanh), x) = mx.Activation(data = x, act_type = :tanh)

graph(::typeof(flatten), x) = mx.Flatten(data = x)
|
||
|
||
# Softmax over the feature dimension. NOTE(review): assumes inputs are laid out
# as (features, batch), consistent with the batching adjustment used for `*`.
# Each sample is normalised by the sum of its own exponentials; summing over
# every element would mix samples whenever the batch size is greater than 1.
function graph(::typeof(softmax), xs)
    exps = exp(xs)
    return mx.broadcast_div(exps, mx.sum(exps, axis = 1, keepdims = true))
end
|
||
|
||
# Concatenation of symbolic nodes along `dim`; `vcat` is the `dim = 1` case.
graph(::typeof(cat), dim::Integer, parts...) = mx.Concat(parts..., dim = dim)
graph(::typeof(vcat), parts...) = graph(cat, 1, parts...)
|
||
|
||
# An `Input` layer adds no operation: its argument is forwarded unchanged.
graph(::Input, input) = input
|
||
|
||
# Translate an `Affine` layer into an MXNet fully-connected node. The layer's
# `W` and `b` become MXNet variables (their values are recorded by `var`), and
# the resulting node is registered so that MXNet errors can be traced back to it.
function graph(ctx::Context, d::Affine, x)
    fc = mx.FullyConnected(data = x,
                           num_hidden = size(d.W.x, 2),
                           weight = var(ctx, d.W),
                           bias = var(ctx, d.b, size(d.b, 2)))
    return register(ctx, fc)
end
|
||
|
||
# TODO: use the layer's actual parameters (currently only the filter's shape is passed to MXNet)
|
||
# Translate a `Conv2D` layer into an MXNet convolution node. Only the filter's
# shape is used (kernel from its first two dimensions, filter count from the
# fourth); no weight or bias variables are passed in. NOTE(review): MXNet
# presumably creates its own parameters for this node.
# The node is registered, as for `Affine`, so that MXNet errors raised inside it
# can be mapped back to the Julia call stack.
graph(ctx::Context, c::Conv2D, x) =
    register(ctx,
             mx.Convolution(data = x,
                            kernel = size(c.filter, 1, 2),
                            num_filter = size(c.filter, 4),
                            stride = c.stride))
|
||
|
||
# Translate a `MaxPool` layer into an MXNet max-pooling node using the layer's
# window size and stride. The node is registered, as for `Affine`, so that
# MXNet errors raised inside it can be mapped back to the Julia call stack.
graph(ctx::Context, p::MaxPool, x) =
    register(ctx,
             mx.Pooling(data = x,
                        pool_type = :max,
                        kernel = p.size,
                        stride = p.stride))
|
||
|
||
# Remember the interpreter stack at which a symbolic node was created, keyed by
# the node's MXNet name in `ctx[:stacks]`, then hand the node back to the caller.
function register(ctx::Context, node::mx.SymbolicNode)
    stacks = ctx[:stacks]
    stacks[nodename(node)] = stack(ctx)
    return node
end

# Anything that is not a symbolic node (e.g. a tuple) is returned untouched.
register(ctx::Context, node) = node
|
||
|
||
# Create a fresh MXNet variable for parameter `p` and record its value in
# `ctx[:params]` under the variable's gensym'd name. When `size` is given (an
# integer or a tuple of dimensions), the stored value is `p.x` reshaped to it.
function var(ctx::Context, p::Flux.Param, size = nothing)
    id = gensym()
    ctx[:params][id] = size === nothing ? p.x : reshape(p.x, size...)
    return mx.Variable(id)
end
|
||
|
||
# A constant wrapping an array-valued parameter becomes a new MXNet variable
# (its value is stored in `ctx[:params]` by `var`).
graph{A<:AArray}(ctx::Context, c::Constant{Flux.Param{A}}) = var(ctx, c.value)

# Every other constant is converted straight to symbolic form via `node`.
graph(ctx::Context, c::Constant) = node(c.value)
|
||
|
||
# Generic model translation. When Flux provides no DataFlow graph for `model`
# (`Flux.graph` returns `nothing`), the model is treated as a primitive: its
# context-free `graph(model, args...)` method is called and the result is
# registered. Otherwise the model's graph is interpreted within `ctx`; cyclic
# graphs (e.g. un-unrolled recurrences) are rejected.
function graph(ctx::Context, model, args...)
    g = Flux.graph(model)
    # `===` avoids dispatching a user/DataFlow-defined `==` on the graph object.
    g === nothing && return register(ctx, @icatch ctx graph(model, args...))
    DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
    return interpret(ctx, g, args...)
end
|
||
|
||
# Interpreter hook: dispatch to `graph` with the context, with any error caught by `@icatch ctx`.
graph′(ctx::Context, xs...) = @icatch ctx graph(ctx, xs...)
|
||
|
||
# Translate `model`, applied to `args`, into an MXNet symbolic graph.
# Returns `(params, stacks, out)`:
# - `params`: parameter values keyed by the names of the variables created for them;
# - `stacks`: the interpreter stack recorded for each registered node;
# - `out`: the resulting symbolic output.
function tograph(model, args...)
    interp = mux(iline, ilambda, imap, iargs, ituple, graph′)
    ctx = Context(interp, params = Dict(), stacks = Dict())
    result = @ithrow graph(ctx, model, args...)
    return ctx[:params], ctx[:stacks], result
end
|
||
|
||
# Error Handling
|
||
|
||
using Juno
|
||
# Let Juno display an MXNet error using its raw message text.
Juno.errmsg(err::mx.MXError) = err.msg
|
||
|
||
# Extract the name of the offending node from an MXNet error message: the word
# following "Error in ". Returns `nothing` when the message has no such phrase.
function errnode(e::mx.MXError)
    m = match(r"Error in (\w+)", e.msg)
    m === nothing && return nothing
    return Symbol(m.captures[1])
end
|
||
|
||
# `@mxerr stacks expr` evaluates `expr`. If it throws an MXNet error whose
# message names a node present in `stacks` (a dict from node name to recorded
# interpreter stack, such as the one returned by `tograph`), the error is
# rethrown as an interpreter `Exception` carrying that node's call trace.
# Every other error is rethrown unchanged.
macro mxerr(stk, ex)
    quote
        try
            $(esc(ex))
        catch err
            node = isa(err, mx.MXError) ? errnode(err) : nothing
            node === nothing && rethrow()
            stacks = $(esc(stk))
            # A node that was never registered has no stack to report. Without
            # this guard the lookup would raise a KeyError that hides the
            # original MXNet error.
            haskey(stacks, node) || rethrow()
            throw(Exception(err, totrace(stacks[node])))
        end
    end
end
|