chain graph

This commit is contained in:
Mike J Innes 2017-10-05 19:19:25 +01:00
parent 2e2684b051
commit e9c781cd10
3 changed files with 8 additions and 1 deletions

View File

@ -21,6 +21,7 @@ using .Optimise
include("graph/Graph.jl")
using .Graph
import .Graph: graph
include("utils.jl")
include("onehot.jl")

View File

@ -2,6 +2,8 @@ module Graph
export @net
graph(x) = nothing
macro net(ex)
esc(ex)
end

View File

@ -1,3 +1,5 @@
using DataFlow: vertex, inputnode
"""
Chain(layers...)
@ -22,12 +24,14 @@ end
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
@forward Chain.layers Base.start, Base.next, Base.done
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
children(c::Chain) = c.layers
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
graph(s::Chain) = foldl((v, m) -> vertex(m, v), inputnode(1), s.layers)
function Base.show(io::IO, c::Chain)
print(io, "Chain(")