diff --git a/examples/translation.jl b/examples/translation.jl
index df2559cb..89129ad2 100644
--- a/examples/translation.jl
+++ b/examples/translation.jl
@@ -1,52 +1,50 @@
 # Based on https://arxiv.org/abs/1409.0473
 
 using Flux
-using Flux: flip
+using Flux: flip, stateless, broadcastto, ∘
+
+Nbatch = 3   # Number of phrases to batch together
+Nphrase = 5  # The length of (padded) phrases
+Nalpha = 7   # The size of the token vector
+Nhidden = 10 # The size of the hidden state
 
 # A recurrent model which takes a token and returns a context-dependent
 # annotation.
 
-@net type Encoder
-  forward
-  backward
-  token -> hcat(forward(token), backward(token))
-end
+forward = LSTM(Nalpha, Nhidden÷2)
+backward = flip(LSTM(Nalpha, Nhidden÷2))
+encoder = @net token -> hcat(forward(token), backward(token))
 
-Encoder(in::Integer, out::Integer) =
-  Encoder(LSTM(in, out÷2), flip(LSTM(in, out÷2)))
+alignnet = Affine(2Nhidden, 1)
+align = @net (s, t) -> alignnet(hcat(broadcastto(s, (Nbatch, 1)), t))
 
 # A recurrent model which takes a sequence of annotations, attends, and returns
 # a predicted output token.
 
-@net type Decoder
-  attend
-  recur
-  state; y; N
-  function (anns)
-    energies = map(ann -> exp(attend(hcat(state{-1}, ann))[1]), seq(anns, N))
-    weights = energies./sum(energies)
-    ctx = sum(map((α, ann) -> α .* ann, weights, anns))
-    (_, state), y = recur((state{-1},y{-1}), ctx)
-    y
-  end
+recur   = unroll1(LSTM(Nhidden, Nhidden)).model
+state   = param(zeros(1, Nhidden))
+y       = param(zeros(1, Nhidden))
+toalpha = Affine(Nhidden, Nalpha)
+
+decoder = @net function (tokens)
+  energies = map(token -> exp.(align(state{-1}, token)), tokens)
+  weights = map(e -> e ./ sum(energies), energies)
+  context = sum(map(∘, weights, tokens))
+  (y, state), _ = recur((y{-1},state{-1}), context)
+  return softmax(toalpha(y))
 end
 
-Decoder(in::Integer, out::Integer; N = 1) =
-  Decoder(Affine(in+out, 1),
-          unroll1(LSTM(in, out)),
-          param(zeros(1, out)), param(zeros(1, out)), N)
+# Building the full model
 
-# The model
+a, b = rand(Nbatch, Nalpha), rand(Nbatch, Nalpha)
 
-Nalpha = 5 # The size of the input token vector
-Nphrase = 7 # The length of (padded) phrases
-Nhidden = 12 # The size of the hidden state
+model = @Chain(
+  stateless(unroll(encoder, Nphrase)),
+  @net(x -> repeated(x, Nphrase)),
+  stateless(unroll(decoder, Nphrase)))
 
-encode = Encoder(Nalpha, Nhidden)
-decode = Chain(Decoder(Nhidden, Nhidden, N = Nphrase), Affine(Nhidden, Nalpha), softmax)
+model = mxnet(Flux.SeqModel(model, Nphrase))
 
-model = Chain(
-  unroll(encode, Nphrase, stateful = false),
-  unroll(decode, Nphrase, stateful = false, seq = false))
+xs = Batch(Seq(rand(Nalpha) for i = 1:Nphrase) for i = 1:Nbatch)
 
-xs = Batch([Seq(rand(Float32, Nalpha) for _ = 1:Nphrase)])
+model(xs)
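
For reference, here is a rough standalone sketch of what the energies/weights/context lines in the new decoder compute, written in plain Julia with no Flux machinery. The `score` function is a hypothetical stand-in for the `align` network above, and the weighted sum via `.*` is what the imported `∘` is assumed to do here; this is an illustration of the attention step from the Bahdanau et al. paper, not the patch's API.

# Hedged illustration only: `score` is a hypothetical stand-in for `align`,
# and `.*` plays the role assumed for `∘` (scale each annotation by its weight).
Nphrase, Nhidden = 5, 10
tokens = [rand(1, Nhidden) for _ = 1:Nphrase]  # encoder annotations for one phrase
state  = rand(1, Nhidden)                      # previous decoder state
score(s, t) = sum(hcat(s, t))                  # stand-in for alignnet(hcat(s, t))

energies = [exp(score(state, t)) for t in tokens]          # unnormalised attention scores
weights  = [e / sum(energies) for e in energies]           # softmax over the phrase
context  = sum(w .* t for (w, t) in zip(weights, tokens))  # weighted sum of annotations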