diff --git a/REQUIRE b/REQUIRE index b31bc6ad..b75b59ca 100644 --- a/REQUIRE +++ b/REQUIRE @@ -7,6 +7,8 @@ Requires Adapt GZip Colors +ZipFile +AbstractTrees # AD ForwardDiff 0.5.0 diff --git a/src/Flux.jl b/src/Flux.jl index 88b2108e..159f7325 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -34,10 +34,11 @@ include("layers/conv.jl") include("layers/recurrent.jl") include("layers/normalisation.jl") -include("jit/JIT.jl") - +include("batches/Batches.jl") include("data/Data.jl") +include("jit/JIT.jl") + @require CuArrays include("cuda/cuda.jl") end # module diff --git a/src/batches/Batches.jl b/src/batches/Batches.jl new file mode 100644 index 00000000..a2424549 --- /dev/null +++ b/src/batches/Batches.jl @@ -0,0 +1,9 @@ +module Batches + +import ..Flux + +export Tree + +include("tree.jl") + +end diff --git a/src/batches/tree.jl b/src/batches/tree.jl new file mode 100644 index 00000000..5067714a --- /dev/null +++ b/src/batches/tree.jl @@ -0,0 +1,42 @@ +using AbstractTrees + +struct Tree{T} + value::T + children::Vector{Tree{T}} +end + +Tree{T}(x::T, xs::Tree{T}...) where T = Tree{T}(x, [xs...]) +Tree{T}(x) where T = Tree(convert(T, x)) + +Tree(x::T, xs::Tree{T}...) where T = Tree{T}(x, xs...) + +AbstractTrees.children(t::Tree) = t.children +AbstractTrees.printnode(io::IO, t::Tree) = show(io, t.value) + +Base.show(io::IO, t::Type{Tree}) = print(io, "Tree") +Base.show(io::IO, t::Type{Tree{T}}) where T = print(io, "Tree{", T, "}") + +function Base.show(io::IO, t::Tree) + println(io, typeof(t)) + print_tree(io, t) +end + +using Juno + +@render Juno.Inline t::Tree begin + render(t) = Juno.Tree(t.value, render.(t.children)) + Juno.Tree(typeof(t), [render(t)]) +end + +Base.getindex(t::Tree, i::Integer) = t.children[i] +Base.getindex(t::Tree, i::Integer, is::Integer...) = t[i][is...] + +# Utilities + +isleaf(t) = isempty(children(t)) + +leaves(xs::Tree) = map(x -> x.value, Leaves(xs)) + +Base.map(f, t::Tree, ts::Tree...) = + Tree{Any}(f(map(t -> t.value, (t, ts...))...), + [map(f, chs...) for chs in zip(map(t -> t.children, (t, ts...))...)]...) diff --git a/src/data/Data.jl b/src/data/Data.jl index 2844d0ae..e2118029 100644 --- a/src/data/Data.jl +++ b/src/data/Data.jl @@ -1,5 +1,7 @@ module Data +import ..Flux + export CMUDict, cmudict deps(path...) = joinpath(@__DIR__, "..", "..", "deps", path...) @@ -12,4 +14,7 @@ include("mnist.jl") include("cmudict.jl") using .CMUDict +include("sentiment.jl") +using .Sentiment + end diff --git a/src/data/sentiment.jl b/src/data/sentiment.jl new file mode 100644 index 00000000..8ac7a5a1 --- /dev/null +++ b/src/data/sentiment.jl @@ -0,0 +1,45 @@ +module Sentiment + +using ZipFile +using ..Data: deps + +function load() + isfile(deps("sentiment.zip")) || + download("https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip", + deps("sentiment.zip")) + return +end + +getfile(r, name) = r.files[findfirst(x -> x.name == name, r.files)] + +function getfile(name) + r = ZipFile.Reader(deps("sentiment.zip")) + text = readstring(getfile(r, "trees/$name")) + close(r) + return text +end + +using ..Flux.Batches + +totree_(n, w) = Tree{Any}((parse(Int, n), w)) +totree_(n, a, b) = Tree{Any}((parse(Int, n), nothing), totree(a), totree(b)) +totree(t::Expr) = totree_(t.args...) + +function parsetree(s) + s = replace(s, r"\$", s -> "\\\$") + s = replace(s, r"[^\s\(\)]+", s -> "\"$s\"") + s = replace(s, " ", ", ") + return totree(parse(s)) +end + +function gettrees(name) + load() + ss = split(getfile("$name.txt"), '\n', keep = false) + return parsetree.(ss) +end + +train() = gettrees("train") +test() = gettrees("test") +dev() = gettrees("dev") + +end