diff --git a/test/compiler.jl b/test/compiler.jl index 2064221f..7af6bad5 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -1,8 +1,17 @@ using DataFlow, MacroTools -using Flux: Affine, Param, Recurrent, squeeze, unsqueeze, stack +using Flux: Param, Recurrent, squeeze, unsqueeze, stack using Flux.Compiler: @net, graph using DataFlow: Line, Frame +@net type Affine + W + b + x -> x*W .+ b +end + +Affine(in::Integer, out::Integer; init = Flux.initn) = + Affine(init(in, out), init(1, out)) + @net type TLP first second @@ -42,7 +51,7 @@ let tlp = TLP(Affine(10, 21), Affine(20, 15)) e end @test e.trace[end].func == :TLP - @test e.trace[end-1].func == Symbol("Flux.Affine") + @test e.trace[end-1].func == Symbol("Affine") end function apply(model, xs, state)