pull out graph-based affine layer
This commit is contained in:
parent
b18d8cd08f
commit
c7a07562d0
@ -1,8 +1,17 @@
|
|||||||
using DataFlow, MacroTools
|
using DataFlow, MacroTools
|
||||||
using Flux: Affine, Param, Recurrent, squeeze, unsqueeze, stack
|
using Flux: Param, Recurrent, squeeze, unsqueeze, stack
|
||||||
using Flux.Compiler: @net, graph
|
using Flux.Compiler: @net, graph
|
||||||
using DataFlow: Line, Frame
|
using DataFlow: Line, Frame
|
||||||
|
|
||||||
|
@net type Affine
|
||||||
|
W
|
||||||
|
b
|
||||||
|
x -> x*W .+ b
|
||||||
|
end
|
||||||
|
|
||||||
|
Affine(in::Integer, out::Integer; init = Flux.initn) =
|
||||||
|
Affine(init(in, out), init(1, out))
|
||||||
|
|
||||||
@net type TLP
|
@net type TLP
|
||||||
first
|
first
|
||||||
second
|
second
|
||||||
@ -42,7 +51,7 @@ let tlp = TLP(Affine(10, 21), Affine(20, 15))
|
|||||||
e
|
e
|
||||||
end
|
end
|
||||||
@test e.trace[end].func == :TLP
|
@test e.trace[end].func == :TLP
|
||||||
@test e.trace[end-1].func == Symbol("Flux.Affine")
|
@test e.trace[end-1].func == Symbol("Affine")
|
||||||
end
|
end
|
||||||
|
|
||||||
function apply(model, xs, state)
|
function apply(model, xs, state)
|
||||||
|
Loading…
Reference in New Issue
Block a user