From 6eb2ec154b9f2a2040d0879d1751b200fdbde6e0 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 6 Nov 2017 12:01:47 +0000 Subject: [PATCH] sentiment treebank loader --- src/Flux.jl | 3 +-- src/batches/tree.jl | 18 ++++++++++++++++-- src/data/Data.jl | 2 ++ src/data/sentiment.jl | 38 +++++++++++++++++++++++++++++++------- 4 files changed, 50 insertions(+), 11 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index acefff19..67b0378a 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -29,8 +29,7 @@ include("layers/basic.jl") include("layers/recurrent.jl") include("layers/normalisation.jl") +include("batches/Batches.jl") include("data/Data.jl") -include("batches/Batches.jl") - end # module diff --git a/src/batches/tree.jl b/src/batches/tree.jl index 29e59c52..f2ece48f 100644 --- a/src/batches/tree.jl +++ b/src/batches/tree.jl @@ -5,8 +5,22 @@ struct Tree{T} children::Vector{Tree{T}} end -Tree(x::T, xs::Vector{Tree{T}} = Tree{T}[]) where T = Tree{T}(x, xs) -Tree(x::T, xs::Tree{T}...) where T = Tree{T}(x, [xs...]) +Tree{T}(x::T, xs::Tree{T}...) where T = Tree{T}(x, [xs...]) +Tree{T}(x) where T = Tree(convert(T, x)) + +Tree(x::T, xs::Tree{T}...) where T = Tree{T}(x, xs...) AbstractTrees.children(t::Tree) = t.children AbstractTrees.printnode(io::IO, t::Tree) = show(io, t.value) + +function Base.show(io::IO, t::Tree) + println(io, typeof(t)) + print_tree(io, t) +end + +using Juno + +@render Juno.Inline t::Tree begin + render(t) = Juno.Tree(t.value, render.(t.children)) + Juno.Tree(typeof(t), [render(t)]) +end diff --git a/src/data/Data.jl b/src/data/Data.jl index 631e45e5..2acccfe2 100644 --- a/src/data/Data.jl +++ b/src/data/Data.jl @@ -1,5 +1,7 @@ module Data +import ..Flux + export CMUDict, cmudict deps(path...) = joinpath(@__DIR__, "..", "..", "deps", path...) diff --git a/src/data/sentiment.jl b/src/data/sentiment.jl index 7917f302..8ac7a5a1 100644 --- a/src/data/sentiment.jl +++ b/src/data/sentiment.jl @@ -1,21 +1,45 @@ module Sentiment +using ZipFile using ..Data: deps function load() - isfile(deps("stanfordSentimentTreebank.zip")) || - download("http://nlp.stanford.edu/~socherr/stanfordSentimentTreebank.zip", - deps("stanfordSentimentTreebank.zip")) + isfile(deps("sentiment.zip")) || + download("https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip", + deps("sentiment.zip")) return end getfile(r, name) = r.files[findfirst(x -> x.name == name, r.files)] -function loadtext() - r = ZipFile.Reader(deps("stanfordSentimentTreebank.zip")) - sentences = readstring(getfile(r, "stanfordSentimentTreebank/datasetSentences.txt")) +function getfile(name) + r = ZipFile.Reader(deps("sentiment.zip")) + text = readstring(getfile(r, "trees/$name")) close(r) - return sentences + return text end +using ..Flux.Batches + +totree_(n, w) = Tree{Any}((parse(Int, n), w)) +totree_(n, a, b) = Tree{Any}((parse(Int, n), nothing), totree(a), totree(b)) +totree(t::Expr) = totree_(t.args...) + +function parsetree(s) + s = replace(s, r"\$", s -> "\\\$") + s = replace(s, r"[^\s\(\)]+", s -> "\"$s\"") + s = replace(s, " ", ", ") + return totree(parse(s)) +end + +function gettrees(name) + load() + ss = split(getfile("$name.txt"), '\n', keep = false) + return parsetree.(ss) +end + +train() = gettrees("train") +test() = gettrees("test") +dev() = gettrees("dev") + end