Merge branch 'treebank'
This commit is contained in:
commit
34217b1fa2
2
REQUIRE
2
REQUIRE
@ -7,6 +7,8 @@ Requires
|
|||||||
Adapt
|
Adapt
|
||||||
GZip
|
GZip
|
||||||
Colors
|
Colors
|
||||||
|
ZipFile
|
||||||
|
AbstractTrees
|
||||||
|
|
||||||
# AD
|
# AD
|
||||||
ForwardDiff 0.5.0
|
ForwardDiff 0.5.0
|
||||||
|
@ -34,10 +34,11 @@ include("layers/conv.jl")
|
|||||||
include("layers/recurrent.jl")
|
include("layers/recurrent.jl")
|
||||||
include("layers/normalisation.jl")
|
include("layers/normalisation.jl")
|
||||||
|
|
||||||
include("jit/JIT.jl")
|
include("batches/Batches.jl")
|
||||||
|
|
||||||
include("data/Data.jl")
|
include("data/Data.jl")
|
||||||
|
|
||||||
|
include("jit/JIT.jl")
|
||||||
|
|
||||||
@require CuArrays include("cuda/cuda.jl")
|
@require CuArrays include("cuda/cuda.jl")
|
||||||
|
|
||||||
end # module
|
end # module
|
||||||
|
9
src/batches/Batches.jl
Normal file
9
src/batches/Batches.jl
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
module Batches
|
||||||
|
|
||||||
|
import ..Flux
|
||||||
|
|
||||||
|
export Tree
|
||||||
|
|
||||||
|
include("tree.jl")
|
||||||
|
|
||||||
|
end
|
42
src/batches/tree.jl
Normal file
42
src/batches/tree.jl
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
using AbstractTrees
|
||||||
|
|
||||||
|
"""
    Tree{T}

A tree node holding a payload `value` of type `T` and an ordered vector of
child subtrees with the same parameter `T`.

    Tree(x, children::Tree...)
    Tree{T}(x, children::Tree{T}...)

Build a node from a root value followed by zero or more subtrees.
"""
struct Tree{T}
    value::T
    children::Vector{Tree{T}}
end

# Varargs form: gather the trailing subtrees into the `children` vector and
# hand over to the field-wise constructor.
function Tree{T}(x::T, xs::Tree{T}...) where T
    kids = Tree{T}[xs...]
    return Tree{T}(x, kids)
end

# Single-argument form: convert `x` to `T`, then delegate to the untyped
# constructor, which takes the parameter from the converted value.
function Tree{T}(x) where T
    converted = convert(T, x)
    return Tree(converted)
end

# Untyped form: the parameter `T` is taken from the type of the root value.
function Tree(x::T, xs::Tree{T}...) where T
    return Tree{T}(x, xs...)
end
|
||||||
|
|
||||||
|
# AbstractTrees interface: the subtrees of a node are its stored `children`.
function AbstractTrees.children(t::Tree)
    return t.children
end

# AbstractTrees interface: a node is printed as the `show` form of its value
# only; the children are not part of the node's own text.
function AbstractTrees.printnode(io::IO, t::Tree)
    return show(io, t.value)
end
|
||||||
|
|
||||||
|
# Display of the type objects themselves: the unparameterised type prints as
# the bare name `Tree`.
function Base.show(io::IO, ::Type{Tree})
    return print(io, "Tree")
end

# A parameterised type prints as `Tree{T}` with `T` printed in place.
function Base.show(io::IO, ::Type{Tree{T}}) where T
    return print(io, "Tree{", T, "}")
end
|
||||||
|
|
||||||
|
# Display of a tree value: the concrete type on its own line, followed by the
# rendering produced by `print_tree`.
function Base.show(io::IO, t::Tree)
    print(io, typeof(t), "\n")
    return print_tree(io, t)
end
|
||||||
|
|
||||||
|
using Juno
|
||||||
|
|
||||||
|
# Inline Juno rendering: a `Juno.Tree` headed by the tree's type whose single
# child is the recursively converted tree.
@render Juno.Inline t::Tree begin
    # Convert one node (and, recursively, its children) into a `Juno.Tree`
    # headed by the node's value.
    function render(node)
        subtrees = render.(node.children)
        return Juno.Tree(node.value, subtrees)
    end
    Juno.Tree(typeof(t), [render(t)])
end
|
||||||
|
|
||||||
|
# `t[i]` selects the i-th child subtree of `t`.
function Base.getindex(t::Tree, i::Integer)
    return t.children[i]
end

# `t[i, j, ...]` descends one level per index: child `i` of `t`, then the
# remaining indices are applied to that child.
function Base.getindex(t::Tree, i::Integer, is::Integer...)
    child = t[i]
    return child[is...]
end
|
||||||
|
|
||||||
|
# Utilities
|
||||||
|
|
||||||
|
# A node is a leaf when `children` reports no subtrees for it.
function isleaf(t)
    kids = children(t)
    return isempty(kids)
end

# Collect the `value` of every node yielded by the `Leaves` iterator over `xs`.
function leaves(xs::Tree)
    return map(Leaves(xs)) do node
        node.value
    end
end
|
||||||
|
|
||||||
|
# Map `f` over one or more trees in lockstep.
#
# The root of the result is `f` applied to the root values of all the trees;
# the children are obtained by recursing on the child subtrees taken
# position by position. `zip` stops at the shortest child list, so extra
# children of the longer trees are ignored. The result is always a `Tree{Any}`.
function Base.map(f, t::Tree, ts::Tree...)
    trees = (t, ts...)
    payloads = map(node -> node.value, trees)
    # Apply `f` to the roots before descending, as the one-line form did.
    root = f(payloads...)
    childlists = map(node -> node.children, trees)
    mapped = [map(f, group...) for group in zip(childlists...)]
    return Tree{Any}(root, mapped...)
end
|
@ -1,5 +1,7 @@
|
|||||||
module Data
|
module Data
|
||||||
|
|
||||||
|
import ..Flux
|
||||||
|
|
||||||
export CMUDict, cmudict
|
export CMUDict, cmudict
|
||||||
|
|
||||||
deps(path...) = joinpath(@__DIR__, "..", "..", "deps", path...)
|
deps(path...) = joinpath(@__DIR__, "..", "..", "deps", path...)
|
||||||
@ -12,4 +14,7 @@ include("mnist.jl")
|
|||||||
include("cmudict.jl")
|
include("cmudict.jl")
|
||||||
using .CMUDict
|
using .CMUDict
|
||||||
|
|
||||||
|
include("sentiment.jl")
|
||||||
|
using .Sentiment
|
||||||
|
|
||||||
end
|
end
|
||||||
|
45
src/data/sentiment.jl
Normal file
45
src/data/sentiment.jl
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
module Sentiment
|
||||||
|
|
||||||
|
using ZipFile
|
||||||
|
using ..Data: deps
|
||||||
|
|
||||||
|
# Make sure the sentiment archive is present in the deps directory,
# downloading it when no file exists at that path yet. Always returns `nothing`.
function load()
    archive = deps("sentiment.zip")
    if !isfile(archive)
        download("https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip",
                 archive)
    end
    return
end
|
||||||
|
|
||||||
|
# Look up the entry called `name` among the files of an open zip reader `r`.
# Indexing fails if no entry has that name.
getfile(r, name) = r.files[findfirst(x -> x.name == name, r.files)]

# Read the whole text of `trees/<name>` from the downloaded sentiment archive.
#
# The reader is closed in a `finally` clause so that the underlying file
# handle is released even when the entry is missing or reading it throws;
# any such error still propagates to the caller unchanged.
function getfile(name)
    r = ZipFile.Reader(deps("sentiment.zip"))
    try
        return readstring(getfile(r, "trees/$name"))
    finally
        close(r)
    end
end
|
||||||
|
|
||||||
|
using ..Flux.Batches
|
||||||
|
|
||||||
|
# Two-element node: a childless tree whose value is the pair
# `(parse(Int, n), w)`.
function totree_(n, w)
    num = parse(Int, n)
    return Tree{Any}((num, w))
end

# Three-element node: a tree whose value is `(parse(Int, n), nothing)` and
# whose two children are the converted sub-expressions `a` and `b`.
function totree_(n, a, b)
    num = parse(Int, n)
    left = totree(a)
    right = totree(b)
    return Tree{Any}((num, nothing), left, right)
end

# Convert a parsed expression into a tree by dispatching on how many
# arguments it carries.
function totree(t::Expr)
    return totree_(t.args...)
end
|
||||||
|
|
||||||
|
# Turn one parenthesised tree line, e.g. `(3 (2 good) (2 film))`, into a `Tree`.
#
# The line is rewritten into Julia tuple syntax: every token that is not
# whitespace or a parenthesis is wrapped in a string literal, and spaces
# become `, `. The result is handed to the Julia parser and then to `totree`.
# The parsed expression is only inspected, never evaluated.
#
# Because tokens end up inside string literals, every character that is
# special inside a literal must be escaped first. Backslashes are escaped
# before anything else so that the escapes added afterwards are not doubled.
# Without this, a token containing `\` or `"` would give an invalid or
# early-terminated literal and the parse would fail.
function parsetree(s)
    s = replace(s, "\\", "\\\\")
    s = replace(s, "\"", "\\\"")
    # `$` would otherwise start string interpolation inside the literal.
    s = replace(s, r"\$", s -> "\\\$")
    s = replace(s, r"[^\s\(\)]+", s -> "\"$s\"")
    s = replace(s, " ", ", ")
    return totree(parse(s))
end
|
||||||
|
|
||||||
|
# Load every tree stored in `trees/<name>.txt` of the sentiment archive.
#
# Makes sure the archive is present, reads the file, splits it into
# non-empty lines and parses each line into a tree.
function gettrees(name)
    load()
    text = getfile(string(name, ".txt"))
    lines = split(text, '\n', keep = false)
    return parsetree.(lines)
end
|
||||||
|
|
||||||
|
# Trees read from `train.txt` in the archive.
function train()
    return gettrees("train")
end

# Trees read from `test.txt` in the archive.
function test()
    return gettrees("test")
end

# Trees read from `dev.txt` in the archive.
function dev()
    return gettrees("dev")
end
|
||||||
|
|
||||||
|
end
|
Loading…
Reference in New Issue
Block a user