From 7e9468d8f868d4beba47fc9df98076968aa689e2 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 2 Nov 2017 11:41:28 +0000 Subject: [PATCH 1/6] treebank skeleton --- REQUIRE | 1 + src/data/Data.jl | 3 +++ src/data/sentiment.jl | 21 +++++++++++++++++++++ 3 files changed, 25 insertions(+) create mode 100644 src/data/sentiment.jl diff --git a/REQUIRE b/REQUIRE index d124b931..ea9cd5e7 100644 --- a/REQUIRE +++ b/REQUIRE @@ -5,3 +5,4 @@ MacroTools 0.3.3 NNlib ForwardDiff Requires +ZipFile diff --git a/src/data/Data.jl b/src/data/Data.jl index ffea729c..631e45e5 100644 --- a/src/data/Data.jl +++ b/src/data/Data.jl @@ -11,4 +11,7 @@ end include("cmudict.jl") using .CMUDict +include("sentiment.jl") +using .Sentiment + end diff --git a/src/data/sentiment.jl b/src/data/sentiment.jl new file mode 100644 index 00000000..7917f302 --- /dev/null +++ b/src/data/sentiment.jl @@ -0,0 +1,21 @@ +module Sentiment + +using ..Data: deps + +function load() + isfile(deps("stanfordSentimentTreebank.zip")) || + download("http://nlp.stanford.edu/~socherr/stanfordSentimentTreebank.zip", + deps("stanfordSentimentTreebank.zip")) + return +end + +getfile(r, name) = r.files[findfirst(x -> x.name == name, r.files)] + +function loadtext() + r = ZipFile.Reader(deps("stanfordSentimentTreebank.zip")) + sentences = readstring(getfile(r, "stanfordSentimentTreebank/datasetSentences.txt")) + close(r) + return sentences +end + +end From 8b05317895e7d4c34a84d1b5d6844d8bc1f092fe Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 2 Nov 2017 12:09:09 +0000 Subject: [PATCH 2/6] basic tree --- REQUIRE | 1 + src/batches/Batches.jl | 1 + src/batches/tree.jl | 12 ++++++++++++ 3 files changed, 14 insertions(+) create mode 100644 src/batches/tree.jl diff --git a/REQUIRE b/REQUIRE index ea9cd5e7..dc772c4e 100644 --- a/REQUIRE +++ b/REQUIRE @@ -6,3 +6,4 @@ NNlib ForwardDiff Requires ZipFile +AbstractTrees diff --git a/src/batches/Batches.jl b/src/batches/Batches.jl index 066f4d1c..5fc3f862 100644 --- a/src/batches/Batches.jl +++ b/src/batches/Batches.jl @@ -3,5 +3,6 @@ module Batches import ..Flux include("batch.jl") +include("tree.jl") end diff --git a/src/batches/tree.jl b/src/batches/tree.jl new file mode 100644 index 00000000..29e59c52 --- /dev/null +++ b/src/batches/tree.jl @@ -0,0 +1,12 @@ +using AbstractTrees + +struct Tree{T} + value::T + children::Vector{Tree{T}} +end + +Tree(x::T, xs::Vector{Tree{T}} = Tree{T}[]) where T = Tree{T}(x, xs) +Tree(x::T, xs::Tree{T}...) where T = Tree{T}(x, [xs...]) + +AbstractTrees.children(t::Tree) = t.children +AbstractTrees.printnode(io::IO, t::Tree) = show(io, t.value) From 8777362eee0c61a64294efc3f117ad824d588573 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 2 Nov 2017 12:11:21 +0000 Subject: [PATCH 3/6] exports --- src/batches/Batches.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/batches/Batches.jl b/src/batches/Batches.jl index 5fc3f862..0146fb00 100644 --- a/src/batches/Batches.jl +++ b/src/batches/Batches.jl @@ -2,6 +2,8 @@ module Batches import ..Flux +export Batch, Tree + include("batch.jl") include("tree.jl") From 6eb2ec154b9f2a2040d0879d1751b200fdbde6e0 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 6 Nov 2017 12:01:47 +0000 Subject: [PATCH 4/6] sentiment treebank loader --- src/Flux.jl | 3 +-- src/batches/tree.jl | 18 ++++++++++++++++-- src/data/Data.jl | 2 ++ src/data/sentiment.jl | 38 +++++++++++++++++++++++++++++++------- 4 files changed, 50 insertions(+), 11 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index acefff19..67b0378a 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -29,8 +29,7 @@ include("layers/basic.jl") include("layers/recurrent.jl") include("layers/normalisation.jl") +include("batches/Batches.jl") include("data/Data.jl") -include("batches/Batches.jl") - end # module diff --git a/src/batches/tree.jl b/src/batches/tree.jl index 29e59c52..f2ece48f 100644 --- a/src/batches/tree.jl +++ b/src/batches/tree.jl @@ -5,8 +5,22 @@ struct Tree{T} children::Vector{Tree{T}} end -Tree(x::T, xs::Vector{Tree{T}} = Tree{T}[]) where T = Tree{T}(x, xs) -Tree(x::T, xs::Tree{T}...) where T = Tree{T}(x, [xs...]) +Tree{T}(x::T, xs::Tree{T}...) where T = Tree{T}(x, [xs...]) +Tree{T}(x) where T = Tree(convert(T, x)) + +Tree(x::T, xs::Tree{T}...) where T = Tree{T}(x, xs...) AbstractTrees.children(t::Tree) = t.children AbstractTrees.printnode(io::IO, t::Tree) = show(io, t.value) + +function Base.show(io::IO, t::Tree) + println(io, typeof(t)) + print_tree(io, t) +end + +using Juno + +@render Juno.Inline t::Tree begin + render(t) = Juno.Tree(t.value, render.(t.children)) + Juno.Tree(typeof(t), [render(t)]) +end diff --git a/src/data/Data.jl b/src/data/Data.jl index 631e45e5..2acccfe2 100644 --- a/src/data/Data.jl +++ b/src/data/Data.jl @@ -1,5 +1,7 @@ module Data +import ..Flux + export CMUDict, cmudict deps(path...) = joinpath(@__DIR__, "..", "..", "deps", path...) diff --git a/src/data/sentiment.jl b/src/data/sentiment.jl index 7917f302..8ac7a5a1 100644 --- a/src/data/sentiment.jl +++ b/src/data/sentiment.jl @@ -1,21 +1,45 @@ module Sentiment +using ZipFile using ..Data: deps function load() - isfile(deps("stanfordSentimentTreebank.zip")) || - download("http://nlp.stanford.edu/~socherr/stanfordSentimentTreebank.zip", - deps("stanfordSentimentTreebank.zip")) + isfile(deps("sentiment.zip")) || + download("https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip", + deps("sentiment.zip")) return end getfile(r, name) = r.files[findfirst(x -> x.name == name, r.files)] -function loadtext() - r = ZipFile.Reader(deps("stanfordSentimentTreebank.zip")) - sentences = readstring(getfile(r, "stanfordSentimentTreebank/datasetSentences.txt")) +function getfile(name) + r = ZipFile.Reader(deps("sentiment.zip")) + text = readstring(getfile(r, "trees/$name")) close(r) - return sentences + return text end +using ..Flux.Batches + +totree_(n, w) = Tree{Any}((parse(Int, n), w)) +totree_(n, a, b) = Tree{Any}((parse(Int, n), nothing), totree(a), totree(b)) +totree(t::Expr) = totree_(t.args...) + +function parsetree(s) + s = replace(s, r"\$", s -> "\\\$") + s = replace(s, r"[^\s\(\)]+", s -> "\"$s\"") + s = replace(s, " ", ", ") + return totree(parse(s)) +end + +function gettrees(name) + load() + ss = split(getfile("$name.txt"), '\n', keep = false) + return parsetree.(ss) +end + +train() = gettrees("train") +test() = gettrees("test") +dev() = gettrees("dev") + end From 752a9e2808884f56b6c9f8f3a8a4706f407692ea Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 7 Nov 2017 10:57:27 +0000 Subject: [PATCH 5/6] tree utilities --- src/batches/tree.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/batches/tree.jl b/src/batches/tree.jl index f2ece48f..5067714a 100644 --- a/src/batches/tree.jl +++ b/src/batches/tree.jl @@ -13,6 +13,9 @@ Tree(x::T, xs::Tree{T}...) where T = Tree{T}(x, xs...) AbstractTrees.children(t::Tree) = t.children AbstractTrees.printnode(io::IO, t::Tree) = show(io, t.value) +Base.show(io::IO, t::Type{Tree}) = print(io, "Tree") +Base.show(io::IO, t::Type{Tree{T}}) where T = print(io, "Tree{", T, "}") + function Base.show(io::IO, t::Tree) println(io, typeof(t)) print_tree(io, t) @@ -24,3 +27,16 @@ using Juno render(t) = Juno.Tree(t.value, render.(t.children)) Juno.Tree(typeof(t), [render(t)]) end + +Base.getindex(t::Tree, i::Integer) = t.children[i] +Base.getindex(t::Tree, i::Integer, is::Integer...) = t[i][is...] + +# Utilities + +isleaf(t) = isempty(children(t)) + +leaves(xs::Tree) = map(x -> x.value, Leaves(xs)) + +Base.map(f, t::Tree, ts::Tree...) = + Tree{Any}(f(map(t -> t.value, (t, ts...))...), + [map(f, chs...) for chs in zip(map(t -> t.children, (t, ts...))...)]...) From ccdc0465466446913d6ccb80ac2152f9f8e20c8d Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 9 Nov 2017 14:52:28 +0000 Subject: [PATCH 6/6] fixes #79 --- src/onehot.jl | 11 +++++++++-- src/tracker/Tracker.jl | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/onehot.jl b/src/onehot.jl index 5414773c..f8061063 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -1,3 +1,5 @@ +import Base: * + struct OneHotVector <: AbstractVector{Bool} ix::UInt32 of::UInt32 @@ -7,7 +9,7 @@ Base.size(xs::OneHotVector) = (Int64(xs.of),) Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix -Base.:*(A::AbstractMatrix, b::OneHotVector) = A[:, b.ix] +A::AbstractMatrix * b::OneHotVector = A[:, b.ix] struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool} height::Int @@ -18,7 +20,7 @@ Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data)) Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i] -Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)] +A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)] Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...]) @@ -47,3 +49,8 @@ argmax(y::AbstractVector, labels = 1:length(y)) = argmax(y::AbstractMatrix, l...) = squeeze(mapslices(y -> argmax(y, l...), y, 1), 1) + +# Ambiguity hack + +a::TrackedMatrix * b::OneHotVector = TrackedArray(Tracker.Call(*, a, b)) +a::TrackedMatrix * b::OneHotMatrix = TrackedArray(Tracker.Call(*, a, b)) diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 5e26a051..3a64fcb7 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -1,6 +1,6 @@ module Tracker -export TrackedArray, param, back! +export TrackedArray, TrackedVector, TrackedMatrix, param, back! data(x) = x istracked(x) = false