chain graph
This commit is contained in:
parent
2e2684b051
commit
e9c781cd10
@ -21,6 +21,7 @@ using .Optimise
|
|||||||
|
|
||||||
include("graph/Graph.jl")
|
include("graph/Graph.jl")
|
||||||
using .Graph
|
using .Graph
|
||||||
|
import .Graph: graph
|
||||||
|
|
||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
include("onehot.jl")
|
include("onehot.jl")
|
||||||
|
@ -2,6 +2,8 @@ module Graph
|
|||||||
|
|
||||||
export @net
|
export @net
|
||||||
|
|
||||||
|
graph(x) = nothing
|
||||||
|
|
||||||
macro net(ex)
|
macro net(ex)
|
||||||
esc(ex)
|
esc(ex)
|
||||||
end
|
end
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
using DataFlow: vertex, inputnode
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Chain(layers...)
|
Chain(layers...)
|
||||||
|
|
||||||
@ -22,12 +24,14 @@ end
|
|||||||
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
|
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
|
||||||
@forward Chain.layers Base.start, Base.next, Base.done
|
@forward Chain.layers Base.start, Base.next, Base.done
|
||||||
|
|
||||||
|
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
|
||||||
|
|
||||||
children(c::Chain) = c.layers
|
children(c::Chain) = c.layers
|
||||||
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
|
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
|
||||||
|
|
||||||
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
|
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
|
||||||
|
|
||||||
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
|
graph(s::Chain) = foldl((v, m) -> vertex(m, v), inputnode(1), s.layers)
|
||||||
|
|
||||||
function Base.show(io::IO, c::Chain)
|
function Base.show(io::IO, c::Chain)
|
||||||
print(io, "Chain(")
|
print(io, "Chain(")
|
||||||
|
Loading…
Reference in New Issue
Block a user