basic line node handling
This commit is contained in:
parent
17449e15a3
commit
5f27e30e68
|
@ -1,5 +1,6 @@
|
|||
using Base: @get!
|
||||
using DataFlow: Constant, constant, Context, interpret, Split, interptuple, interplambda, interpconst
|
||||
using DataFlow: Constant, constant, Context, interpret, Split, interptuple,
|
||||
interplambda, interpconst, interpline, stack
|
||||
using Flux: interpmap
|
||||
using TensorFlow: RawTensor
|
||||
|
||||
|
@ -54,7 +55,7 @@ function interp(ctx, model, args...)
|
|||
end
|
||||
|
||||
function tograph(model, args...)
|
||||
ctx = Context(interplambda(interptuple(interpmap(interp))), params = ObjectIdDict())
|
||||
ctx = Context(interpline(interplambda(interptuple(interpmap(interp)))), params = ObjectIdDict())
|
||||
out = interp(ctx, model, map(constant, args)...)
|
||||
return ctx[:params], out
|
||||
end
|
||||
|
|
|
@ -5,6 +5,6 @@ d = Affine(10, 20)
|
|||
@test d(xs) == xs*d.W.x + d.b.x
|
||||
|
||||
let
|
||||
@capture(syntax(d), x_[1] * W_ + b_)
|
||||
@capture(syntax(d), _Line(x_[1] * W_ + b_))
|
||||
@test isa(x, Input) && isa(W, Param) && isa(b, Param)
|
||||
end
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
using Flux, DataFlow, MacroTools, Base.Test
|
||||
using Flux: graph, Param
|
||||
using DataFlow: Input
|
||||
using DataFlow: Input, Line
|
||||
|
||||
syntax(v::Vertex) = prettify(DataFlow.syntax(v))
|
||||
syntax(x) = syntax(graph(x))
|
||||
|
|
Loading…
Reference in New Issue