fix hintify
This commit is contained in:
parent
07433c13bd
commit
cd86dfdf07
@ -4,7 +4,7 @@ using MacroTools, Lazy, DataFlow, Juno
|
|||||||
using DataFlow: graphm, syntax, prewalk!, postwalk!, prewalk, postwalk,
|
using DataFlow: graphm, syntax, prewalk!, postwalk!, prewalk, postwalk,
|
||||||
iscyclic, Constant, constant, isconstant, group, Split, splitnode,
|
iscyclic, Constant, constant, isconstant, group, Split, splitnode,
|
||||||
detuple, value, inputs, thread!, value, inputs, Split, splitnode, inputnode,
|
detuple, value, inputs, thread!, value, inputs, Split, splitnode, inputnode,
|
||||||
spliceinputs, bumpinputs, Frame, applylines
|
spliceinputs, bumpinputs, Line, Frame, applylines
|
||||||
using Juno: Tree, Row
|
using Juno: Tree, Row
|
||||||
|
|
||||||
# Zero Flux Given
|
# Zero Flux Given
|
||||||
|
@ -15,12 +15,12 @@ end
|
|||||||
ihint(f, ctx::Context, h::Hint, x) = vertex(h, x)
|
ihint(f, ctx::Context, h::Hint, x) = vertex(h, x)
|
||||||
ihint(f, args...) = f(args...)
|
ihint(f, args...) = f(args...)
|
||||||
|
|
||||||
hintify(c::Constant) = hintify(state(c.value))
|
hintify(ctx, c::Constant) = hintify(ctx, state(c.value))
|
||||||
hintify(xs::AbstractArray) = vertex(Hint(size(xs)), constant(:_))
|
hintify(ctx, xs::AbstractArray) = vertex(Hint(size(xs)), constant(:_))
|
||||||
|
|
||||||
interpshape = mux(ilinev, ihint, iargs, ituple, hintify)
|
interpshape = mux(ilinev, ihint, iargs, ituple, hintify)
|
||||||
|
|
||||||
function hintify(f, xs...)
|
function hintify(ctx, f, xs...)
|
||||||
sh = infer(f, map(gethint, xs)...)
|
sh = infer(f, map(gethint, xs)...)
|
||||||
sh ≠ nothing ? vertex(Hint(sh), vertex(f, xs...)) :
|
sh ≠ nothing ? vertex(Hint(sh), vertex(f, xs...)) :
|
||||||
!any(x->x==nothing, xs) && graph(f) ≠ nothing ? interpret(Context(interpshape), graph(f), xs...) :
|
!any(x->x==nothing, xs) && graph(f) ≠ nothing ? interpret(Context(interpshape), graph(f), xs...) :
|
||||||
|
@ -21,4 +21,5 @@ end
|
|||||||
let a1 = Affine(10, 20), a2 = Affine(20, 15)
|
let a1 = Affine(10, 20), a2 = Affine(20, 15)
|
||||||
tlp = TLP(a1, a2)
|
tlp = TLP(a1, a2)
|
||||||
@test tlp(xs) ≈ softmax(a2(σ(a1(xs))))
|
@test tlp(xs) ≈ softmax(a2(σ(a1(xs))))
|
||||||
|
@test Flux.infer(tlp, (1, 10)) == (1,15)
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user