Merge branch 'treebank'
This commit is contained in:
commit
34217b1fa2
2
REQUIRE
2
REQUIRE
@ -7,6 +7,8 @@ Requires
|
||||
Adapt
|
||||
GZip
|
||||
Colors
|
||||
ZipFile
|
||||
AbstractTrees
|
||||
|
||||
# AD
|
||||
ForwardDiff 0.5.0
|
||||
|
@ -34,10 +34,11 @@ include("layers/conv.jl")
|
||||
include("layers/recurrent.jl")
|
||||
include("layers/normalisation.jl")
|
||||
|
||||
include("jit/JIT.jl")
|
||||
|
||||
include("batches/Batches.jl")
|
||||
include("data/Data.jl")
|
||||
|
||||
include("jit/JIT.jl")
|
||||
|
||||
@require CuArrays include("cuda/cuda.jl")
|
||||
|
||||
end # module
|
||||
|
9
src/batches/Batches.jl
Normal file
9
src/batches/Batches.jl
Normal file
@ -0,0 +1,9 @@
|
||||
module Batches
|
||||
|
||||
import ..Flux
|
||||
|
||||
export Tree
|
||||
|
||||
include("tree.jl")
|
||||
|
||||
end
|
42
src/batches/tree.jl
Normal file
42
src/batches/tree.jl
Normal file
@ -0,0 +1,42 @@
|
||||
using AbstractTrees
|
||||
|
||||
struct Tree{T}
|
||||
value::T
|
||||
children::Vector{Tree{T}}
|
||||
end
|
||||
|
||||
Tree{T}(x::T, xs::Tree{T}...) where T = Tree{T}(x, [xs...])
|
||||
Tree{T}(x) where T = Tree(convert(T, x))
|
||||
|
||||
Tree(x::T, xs::Tree{T}...) where T = Tree{T}(x, xs...)
|
||||
|
||||
AbstractTrees.children(t::Tree) = t.children
|
||||
AbstractTrees.printnode(io::IO, t::Tree) = show(io, t.value)
|
||||
|
||||
Base.show(io::IO, t::Type{Tree}) = print(io, "Tree")
|
||||
Base.show(io::IO, t::Type{Tree{T}}) where T = print(io, "Tree{", T, "}")
|
||||
|
||||
function Base.show(io::IO, t::Tree)
|
||||
println(io, typeof(t))
|
||||
print_tree(io, t)
|
||||
end
|
||||
|
||||
using Juno
|
||||
|
||||
@render Juno.Inline t::Tree begin
|
||||
render(t) = Juno.Tree(t.value, render.(t.children))
|
||||
Juno.Tree(typeof(t), [render(t)])
|
||||
end
|
||||
|
||||
Base.getindex(t::Tree, i::Integer) = t.children[i]
|
||||
Base.getindex(t::Tree, i::Integer, is::Integer...) = t[i][is...]
|
||||
|
||||
# Utilities
|
||||
|
||||
isleaf(t) = isempty(children(t))
|
||||
|
||||
leaves(xs::Tree) = map(x -> x.value, Leaves(xs))
|
||||
|
||||
Base.map(f, t::Tree, ts::Tree...) =
|
||||
Tree{Any}(f(map(t -> t.value, (t, ts...))...),
|
||||
[map(f, chs...) for chs in zip(map(t -> t.children, (t, ts...))...)]...)
|
@ -1,5 +1,7 @@
|
||||
module Data
|
||||
|
||||
import ..Flux
|
||||
|
||||
export CMUDict, cmudict
|
||||
|
||||
deps(path...) = joinpath(@__DIR__, "..", "..", "deps", path...)
|
||||
@ -12,4 +14,7 @@ include("mnist.jl")
|
||||
include("cmudict.jl")
|
||||
using .CMUDict
|
||||
|
||||
include("sentiment.jl")
|
||||
using .Sentiment
|
||||
|
||||
end
|
||||
|
45
src/data/sentiment.jl
Normal file
45
src/data/sentiment.jl
Normal file
@ -0,0 +1,45 @@
|
||||
module Sentiment
|
||||
|
||||
using ZipFile
|
||||
using ..Data: deps
|
||||
|
||||
function load()
|
||||
isfile(deps("sentiment.zip")) ||
|
||||
download("https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip",
|
||||
deps("sentiment.zip"))
|
||||
return
|
||||
end
|
||||
|
||||
getfile(r, name) = r.files[findfirst(x -> x.name == name, r.files)]
|
||||
|
||||
function getfile(name)
|
||||
r = ZipFile.Reader(deps("sentiment.zip"))
|
||||
text = readstring(getfile(r, "trees/$name"))
|
||||
close(r)
|
||||
return text
|
||||
end
|
||||
|
||||
using ..Flux.Batches
|
||||
|
||||
totree_(n, w) = Tree{Any}((parse(Int, n), w))
|
||||
totree_(n, a, b) = Tree{Any}((parse(Int, n), nothing), totree(a), totree(b))
|
||||
totree(t::Expr) = totree_(t.args...)
|
||||
|
||||
function parsetree(s)
|
||||
s = replace(s, r"\$", s -> "\\\$")
|
||||
s = replace(s, r"[^\s\(\)]+", s -> "\"$s\"")
|
||||
s = replace(s, " ", ", ")
|
||||
return totree(parse(s))
|
||||
end
|
||||
|
||||
function gettrees(name)
|
||||
load()
|
||||
ss = split(getfile("$name.txt"), '\n', keep = false)
|
||||
return parsetree.(ss)
|
||||
end
|
||||
|
||||
train() = gettrees("train")
|
||||
test() = gettrees("test")
|
||||
dev() = gettrees("dev")
|
||||
|
||||
end
|
Loading…
Reference in New Issue
Block a user