sentiment treebank loader

This commit is contained in:
Mike J Innes 2017-11-06 12:01:47 +00:00
parent 8777362eee
commit 6eb2ec154b
4 changed files with 50 additions and 11 deletions

View File

@ -29,8 +29,7 @@ include("layers/basic.jl")
include("layers/recurrent.jl")
include("layers/normalisation.jl")
include("batches/Batches.jl")
include("data/Data.jl")
include("batches/Batches.jl")
end # module

View File

@ -5,8 +5,22 @@ struct Tree{T}
children::Vector{Tree{T}}
end
Tree(x::T, xs::Vector{Tree{T}} = Tree{T}[]) where T = Tree{T}(x, xs)
Tree(x::T, xs::Tree{T}...) where T = Tree{T}(x, [xs...])
Tree{T}(x::T, xs::Tree{T}...) where T = Tree{T}(x, [xs...])
Tree{T}(x) where T = Tree(convert(T, x))
Tree(x::T, xs::Tree{T}...) where T = Tree{T}(x, xs...)
AbstractTrees.children(t::Tree) = t.children
AbstractTrees.printnode(io::IO, t::Tree) = show(io, t.value)
function Base.show(io::IO, t::Tree)
println(io, typeof(t))
print_tree(io, t)
end
using Juno
@render Juno.Inline t::Tree begin
render(t) = Juno.Tree(t.value, render.(t.children))
Juno.Tree(typeof(t), [render(t)])
end

View File

@ -1,5 +1,7 @@
module Data
import ..Flux
export CMUDict, cmudict
deps(path...) = joinpath(@__DIR__, "..", "..", "deps", path...)

View File

@ -1,21 +1,45 @@
module Sentiment
using ZipFile
using ..Data: deps
function load()
isfile(deps("stanfordSentimentTreebank.zip")) ||
download("http://nlp.stanford.edu/~socherr/stanfordSentimentTreebank.zip",
deps("stanfordSentimentTreebank.zip"))
isfile(deps("sentiment.zip")) ||
download("https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip",
deps("sentiment.zip"))
return
end
getfile(r, name) = r.files[findfirst(x -> x.name == name, r.files)]
function loadtext()
r = ZipFile.Reader(deps("stanfordSentimentTreebank.zip"))
sentences = readstring(getfile(r, "stanfordSentimentTreebank/datasetSentences.txt"))
function getfile(name)
r = ZipFile.Reader(deps("sentiment.zip"))
text = readstring(getfile(r, "trees/$name"))
close(r)
return sentences
return text
end
using ..Flux.Batches
totree_(n, w) = Tree{Any}((parse(Int, n), w))
totree_(n, a, b) = Tree{Any}((parse(Int, n), nothing), totree(a), totree(b))
totree(t::Expr) = totree_(t.args...)
function parsetree(s)
s = replace(s, r"\$", s -> "\\\$")
s = replace(s, r"[^\s\(\)]+", s -> "\"$s\"")
s = replace(s, " ", ", ")
return totree(parse(s))
end
function gettrees(name)
load()
ss = split(getfile("$name.txt"), '\n', keep = false)
return parsetree.(ss)
end
train() = gettrees("train")
test() = gettrees("test")
dev() = gettrees("dev")
end