diff --git a/src/batches/tree.jl b/src/batches/tree.jl index f2ece48f..5067714a 100644 --- a/src/batches/tree.jl +++ b/src/batches/tree.jl @@ -13,6 +13,9 @@ Tree(x::T, xs::Tree{T}...) where T = Tree{T}(x, xs...) AbstractTrees.children(t::Tree) = t.children AbstractTrees.printnode(io::IO, t::Tree) = show(io, t.value) +Base.show(io::IO, t::Type{Tree}) = print(io, "Tree") +Base.show(io::IO, t::Type{Tree{T}}) where T = print(io, "Tree{", T, "}") + function Base.show(io::IO, t::Tree) println(io, typeof(t)) print_tree(io, t) @@ -24,3 +27,16 @@ using Juno render(t) = Juno.Tree(t.value, render.(t.children)) Juno.Tree(typeof(t), [render(t)]) end + +Base.getindex(t::Tree, i::Integer) = t.children[i] +Base.getindex(t::Tree, i::Integer, is::Integer...) = t[i][is...] + +# Utilities + +isleaf(t) = isempty(children(t)) + +leaves(xs::Tree) = map(x -> x.value, Leaves(xs)) + +Base.map(f, t::Tree, ts::Tree...) = + Tree{Any}(f(map(t -> t.value, (t, ts...))...), + [map(f, chs...) for chs in zip(map(t -> t.children, (t, ts...))...)]...)