update translation model

Mike J Innes 2017-06-03 15:52:34 +01:00
parent 10abb64f4b
commit 60b4e1c41c

@@ -1,52 +1,50 @@
 # Based on https://arxiv.org/abs/1409.0473
 
 using Flux
-using Flux: flip
+using Flux: flip, stateless, broadcastto, ∘
+
+Nbatch = 3   # Number of phrases to batch together
+Nphrase = 5  # The length of (padded) phrases
+Nalpha = 7   # The size of the token vector
+Nhidden = 10 # The size of the hidden state
 
 # A recurrent model which takes a token and returns a context-dependent
 # annotation.
 
-@net type Encoder
-  forward
-  backward
-  token -> hcat(forward(token), backward(token))
-end
-
-Encoder(in::Integer, out::Integer) =
-  Encoder(LSTM(in, out÷2), flip(LSTM(in, out÷2)))
+forward = LSTM(Nalpha, Nhidden÷2)
+backward = flip(LSTM(Nalpha, Nhidden÷2))
+encoder = @net token -> hcat(forward(token), backward(token))
+
+alignnet = Affine(2Nhidden, 1)
+align = @net (s, t) -> alignnet(hcat(broadcastto(s, (Nbatch, 1)), t))
 
 # A recurrent model which takes a sequence of annotations, attends, and returns
 # a predicted output token.
 
-@net type Decoder
-  attend
-  recur
-  state; y; N
-  function (anns)
-    energies = map(ann -> exp(attend(hcat(state{-1}, ann))[1]), seq(anns, N))
-    weights = energies./sum(energies)
-    ctx = sum(map((α, ann) -> α .* ann, weights, anns))
-    (_, state), y = recur((state{-1},y{-1}), ctx)
-    y
-  end
-end
-
-Decoder(in::Integer, out::Integer; N = 1) =
-  Decoder(Affine(in+out, 1),
-          unroll1(LSTM(in, out)),
-          param(zeros(1, out)), param(zeros(1, out)), N)
-
-# The model
-
-Nalpha = 5 # The size of the input token vector
-Nphrase = 7 # The length of (padded) phrases
-Nhidden = 12 # The size of the hidden state
-
-encode = Encoder(Nalpha, Nhidden)
-decode = Chain(Decoder(Nhidden, Nhidden, N = Nphrase), Affine(Nhidden, Nalpha), softmax)
-
-model = Chain(
-  unroll(encode, Nphrase, stateful = false),
-  unroll(decode, Nphrase, stateful = false, seq = false))
-
-xs = Batch([Seq(rand(Float32, Nalpha) for _ = 1:Nphrase)])
+recur = unroll1(LSTM(Nhidden, Nhidden)).model
+state = param(zeros(1, Nhidden))
+y = param(zeros(1, Nhidden))
+toalpha = Affine(Nhidden, Nalpha)
+
+decoder = @net function (tokens)
+  energies = map(token -> exp.(align(state{-1}, token)), tokens)
+  weights = map(e -> e ./ sum(energies), energies)
+  context = sum(map(∘, weights, tokens))
+  (y, state), _ = recur((y{-1},state{-1}), context)
+  return softmax(toalpha(y))
+end
+
+# Building the full model
+
+a, b = rand(Nbatch, Nalpha), rand(Nbatch, Nalpha)
+
+model = @Chain(
+  stateless(unroll(encoder, Nphrase)),
+  @net(x -> repeated(x, Nphrase)),
+  stateless(unroll(decoder, Nphrase)))
+
+model = mxnet(Flux.SeqModel(model, Nphrase))
+
+xs = Batch(Seq(rand(Nalpha) for i = 1:Nphrase) for i = 1:Nbatch)
+
+model(xs)
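The attention step the new decoder performs is: score each encoder annotation against the previous decoder state with a small alignment net, softmax-normalise the scores, and take the weighted sum of the annotations as the context vector. Below is a minimal plain-Julia sketch of that step for a single unbatched phrase; the names annotations, prevstate and alignvec, and the toy dimensions, are illustrative stand-ins, not part of the model above.

# Sketch of the decoder's attention step, assuming toy sizes
# matching the example: Nphrase annotations of width Nhidden.
Nphrase, Nhidden = 5, 10

annotations = [randn(Nhidden) for _ in 1:Nphrase]  # encoder outputs, one per token
prevstate = randn(Nhidden)                         # decoder state from the previous step
alignvec = randn(2Nhidden)                         # stand-in for the Affine(2Nhidden, 1) align net

# One unnormalised energy per annotation: align([state; annotation]).
energies = [exp(alignvec' * vcat(prevstate, ann)) for ann in annotations]

# Softmax-normalised attention weights.
weights = energies ./ sum(energies)

# Context vector: attention-weighted sum of the annotations.
context = sum(w * ann for (w, ann) in zip(weights, annotations))

Since the weights sum to one, context is a convex combination of the annotations; this is the quantity the decoder above computes batch-wise via sum(map(∘, weights, tokens)) before feeding it to recur.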