basic line node handling

This commit is contained in:
Mike J Innes 2016-12-20 15:44:00 +00:00
parent 17449e15a3
commit 5f27e30e68
3 changed files with 5 additions and 4 deletions

View File

@ -1,5 +1,6 @@
using Base: @get!
using DataFlow: Constant, constant, Context, interpret, Split, interptuple, interplambda, interpconst
using DataFlow: Constant, constant, Context, interpret, Split, interptuple,
interplambda, interpconst, interpline, stack
using Flux: interpmap
using TensorFlow: RawTensor
@ -54,7 +55,7 @@ function interp(ctx, model, args...)
end
function tograph(model, args...)
ctx = Context(interplambda(interptuple(interpmap(interp))), params = ObjectIdDict())
ctx = Context(interpline(interplambda(interptuple(interpmap(interp)))), params = ObjectIdDict())
out = interp(ctx, model, map(constant, args)...)
return ctx[:params], out
end

View File

@ -5,6 +5,6 @@ d = Affine(10, 20)
@test d(xs) == xs*d.W.x + d.b.x
let
@capture(syntax(d), x_[1] * W_ + b_)
@capture(syntax(d), _Line(x_[1] * W_ + b_))
@test isa(x, Input) && isa(W, Param) && isa(b, Param)
end

View File

@ -1,6 +1,6 @@
using Flux, DataFlow, MacroTools, Base.Test
using Flux: graph, Param
using DataFlow: Input
using DataFlow: Input, Line
syntax(v::Vertex) = prettify(DataFlow.syntax(v))
syntax(x) = syntax(graph(x))