Flux.jl/src/backend/mxnet/graph.jl

33 lines
842 B
Julia
Raw Normal View History

2016-08-22 20:13:28 +00:00
# Unwrap a graph entry to the object it carries:
# - a `Constant` yields its `.value`,
# - a `Vertex` is unwrapped through `value(v)` and then unwrapped again,
# - any other object is returned unchanged.
function cvalue(x)
  return x
end
function cvalue(k::Constant)
  return k.value
end
function cvalue(vert::Vertex)
  inner = value(vert)
  return cvalue(inner)
end
# Fallback lowering: anything without a more specific `graph` method is
# translated through the built-in `node` table, with `vars` left untouched.
function graph(vars, op, xs...)
  return node(op, xs...)
end
# A lone MXNet symbol is already lowered, so hand it back as-is.
function graph(vars, sym::mx.SymbolicNode)
  return sym
end
# TODO: detect parameters used more than once. Every call below mints a fresh
# gensym'd variable, so a Param reached twice in one graph presumably becomes
# two independent MXNet inputs bound to the same array — confirm whether that
# is acceptable for gradient accumulation.
#
# Lower a Flux parameter into an MXNet input:
# - store the parameter's backing array `p.x` in `vars` under a freshly
#   generated symbol,
# - return an `mx.Variable` with that same name.
# NOTE(review): assumes `vars` is a mutable, Symbol-keyed map (e.g. a Dict);
# confirm against the caller.
function graph{T<:AArray}(vars, p::Flux.Param{T})
  id = gensym()
  vars[id] = p.x
  return mx.Variable(id)
end
# Lower a Flux `Model` applied to `args` into an MXNet symbol.
#
# The model's Flow graph (`Flux.graph(model)`) contains `Flux.Parameter`
# placeholders, which are substituted first:
# - an integer-named placeholder takes the positional argument `args[name]`,
# - any other name is read as the field `getfield(model, name)`.
#
# The substituted graph is then rebuilt with `postwalk`. Each vertex is lowered
# by calling `graph(vars, op, inputs...)` on its unwrapped operation and inputs.
# NOTE(review): this presumes postwalk rewrites children first, so the inputs
# are already lowered; confirm against Flow.postwalk.
#
# The final vertex is unwrapped with `value`. `vars` is passed down unchanged to
# every nested `graph` call.
function graph(vars, model::Model, args...)
  substituted = Flow.mapconst(Flux.graph(model)) do c
    if isa(c, Flux.Parameter)
      isa(c.name, Integer) ? args[c.name] : getfield(model, c.name)
    else
      c
    end
  end
  lowered = postwalk(substituted) do v
    vertex(graph(vars, cvalue(v), cvalue.(inputs(v))...))
  end
  return value(lowered)
end
# Built-in implementations
# `*` lowers to `mx.dot`, with the operands passed through in order.
# NOTE(review): MXNet.jl reverses array dimensions relative to the MXNet core.
# Confirm that this operand order reproduces Julia's `*` semantics.
function node(::typeof(*), xs...)
  return mx.dot(xs...)
end
# `+` lowers to `mx.broadcast_plus`.
function node(::typeof(+), xs...)
  return mx.broadcast_plus(xs...)
end
# `σ` lowers to an MXNet Activation node with a sigmoid activation.
function node(::typeof(σ), input)
  return mx.Activation(data = input, act_type = :sigmoid)
end