diff --git a/src/Flux.jl b/src/Flux.jl index f77669a8..f9b8f3f4 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -21,6 +21,7 @@ using .Optimise include("graph/Graph.jl") using .Graph +import .Graph: graph include("utils.jl") include("onehot.jl") diff --git a/src/graph/Graph.jl b/src/graph/Graph.jl index 5db0435d..a452200a 100644 --- a/src/graph/Graph.jl +++ b/src/graph/Graph.jl @@ -2,6 +2,8 @@ module Graph export @net +graph(x) = nothing + macro net(ex) esc(ex) end diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 06a5b162..ec55f9f6 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -1,3 +1,5 @@ +using DataFlow: vertex, inputnode + """ Chain(layers...) @@ -22,12 +24,14 @@ end @forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push! @forward Chain.layers Base.start, Base.next, Base.done +Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...) + children(c::Chain) = c.layers mapchildren(f, c::Chain) = Chain(f.(c.layers)...) (s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers) -Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...) +graph(s::Chain) = foldl((v, m) -> vertex(m, v), inputnode(1), s.layers) function Base.show(io::IO, c::Chain) print(io, "Chain(")