# Based on https://arxiv.org/abs/1409.0473

using Flux
using Flux: flip, stateless, broadcastto, ∘

Nbatch  = 3  # Number of phrases to batch together
Nphrase = 5  # The length of (padded) phrases
Nalpha  = 7  # The size of the token vector
Nhidden = 10 # The size of the hidden state

# A recurrent model which takes a token and returns a context-dependent
# annotation.
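# Each annotation concatenates a forward and a backward LSTM pass over the
# phrase, so each direction gets half of the hidden size.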
forward  = LSTM(Nalpha, Nhidden÷2)
backward = flip(LSTM(Nalpha, Nhidden÷2))
encoder  = @net token -> hcat(forward(token), backward(token))
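
# The alignment (attention) model scores how well the previous decoder state
# matches an annotation: the state is replicated across the batch, concatenated
# with the annotation, and passed through a dense layer to give one score per
# phrase.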
alignnet = Affine(2Nhidden, 1)
align = @net (s, t) -> alignnet(hcat(broadcastto(s, (Nbatch, 1)), t))
# A recurrent model which takes a sequence of annotations, attends, and returns
# a predicted output token.
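# It carries its own hidden state and previous output between steps and emits a
# softmax over the Nalpha-dimensional token space.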
recur   = unroll1(LSTM(Nhidden, Nhidden)).model
state   = param(zeros(1, Nhidden))
y       = param(zeros(1, Nhidden))
toalpha = Affine(Nhidden, Nalpha)
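
# `state` and `y` are parameterised initial values for the decoder's carried
# state and previous output; the `{-1}` syntax below refers to their values
# from the previous time step.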
decoder = @net function (tokens)
  # Score each annotation against the previous decoder state and
  # softmax-normalise the scores into attention weights.
  energies = map(token -> exp.(align(state{-1}, token)), tokens)
  weights = map(e -> e ./ sum(energies), energies)
  # The context is the annotations combined according to their weights.
  context = sum(map(∘, weights, tokens))
  # Step the decoder LSTM on the context and project back to token space.
  (y, state), _ = recur((y{-1}, state{-1}), context)
  return softmax(toalpha(y))
end
# Building the full model
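# The encoder is unrolled over the phrase, its annotations are repeated once
# per output step, and the decoder is unrolled over those to emit the output
# phrase.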
a, b = rand(Nbatch, Nalpha), rand(Nbatch, Nalpha) # a couple of example token batches
model = @Chain(
  stateless(unroll(encoder, Nphrase)),
  @net(x -> repeated(x, Nphrase)),
  stateless(unroll(decoder, Nphrase)))
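
# Compile the chain to the MXNet backend, wrapping it as a sequence model over
# Nphrase time steps.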
model = mxnet(Flux.SeqModel(model, Nphrase))
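
# A dummy input: a batch of Nbatch phrases, each a sequence of Nphrase random
# token vectors.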
xs = Batch(Seq(rand(Nalpha) for i = 1:Nphrase) for i = 1:Nbatch)
model(xs) # run the forward pass on the dummy batch