basic line node handling
This commit is contained in:
parent
17449e15a3
commit
5f27e30e68
@ -1,5 +1,6 @@
|
|||||||
using Base: @get!
|
using Base: @get!
|
||||||
using DataFlow: Constant, constant, Context, interpret, Split, interptuple, interplambda, interpconst
|
using DataFlow: Constant, constant, Context, interpret, Split, interptuple,
|
||||||
|
interplambda, interpconst, interpline, stack
|
||||||
using Flux: interpmap
|
using Flux: interpmap
|
||||||
using TensorFlow: RawTensor
|
using TensorFlow: RawTensor
|
||||||
|
|
||||||
@ -54,7 +55,7 @@ function interp(ctx, model, args...)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function tograph(model, args...)
|
function tograph(model, args...)
|
||||||
ctx = Context(interplambda(interptuple(interpmap(interp))), params = ObjectIdDict())
|
ctx = Context(interpline(interplambda(interptuple(interpmap(interp)))), params = ObjectIdDict())
|
||||||
out = interp(ctx, model, map(constant, args)...)
|
out = interp(ctx, model, map(constant, args)...)
|
||||||
return ctx[:params], out
|
return ctx[:params], out
|
||||||
end
|
end
|
||||||
|
@ -5,6 +5,6 @@ d = Affine(10, 20)
|
|||||||
@test d(xs) == xs*d.W.x + d.b.x
|
@test d(xs) == xs*d.W.x + d.b.x
|
||||||
|
|
||||||
let
|
let
|
||||||
@capture(syntax(d), x_[1] * W_ + b_)
|
@capture(syntax(d), _Line(x_[1] * W_ + b_))
|
||||||
@test isa(x, Input) && isa(W, Param) && isa(b, Param)
|
@test isa(x, Input) && isa(W, Param) && isa(b, Param)
|
||||||
end
|
end
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
using Flux, DataFlow, MacroTools, Base.Test
|
using Flux, DataFlow, MacroTools, Base.Test
|
||||||
using Flux: graph, Param
|
using Flux: graph, Param
|
||||||
using DataFlow: Input
|
using DataFlow: Input, Line
|
||||||
|
|
||||||
syntax(v::Vertex) = prettify(DataFlow.syntax(v))
|
syntax(v::Vertex) = prettify(DataFlow.syntax(v))
|
||||||
syntax(x) = syntax(graph(x))
|
syntax(x) = syntax(graph(x))
|
||||||
|
Loading…
Reference in New Issue
Block a user