
Google Neural Machine Translation

In this notebook, we are going to train Google NMT on the IWSLT 2015 English-Vietnamese dataset. The building process includes four steps: 1) load and process the dataset, 2) create the sampler and DataLoader, 3) build the model, and 4) write the training loop.

Load MXNet and Gluon

In [1]:
import warnings
warnings.filterwarnings('ignore')

import argparse
import time
import random
import os
import io
import logging
import numpy as np
import mxnet as mx
from mxnet import gluon
import gluonnlp as nlp
import nmt

Hyper-parameters

In [2]:
np.random.seed(100)
random.seed(100)
mx.random.seed(10000)
ctx = mx.gpu(0)

# parameters for dataset
dataset = 'IWSLT2015'
src_lang, tgt_lang = 'en', 'vi'
src_max_len, tgt_max_len = 50, 50

# parameters for model
num_hidden = 512
num_layers = 2
num_bi_layers = 1
dropout = 0.2

# parameters for training
batch_size, test_batch_size = 128, 32
num_buckets = 5
epochs = 1
clip = 5
lr = 0.001
lr_update_factor = 0.5
log_interval = 10
save_dir = 'gnmt_en_vi_u512'

# parameters for testing
beam_size = 10
lp_alpha = 1.0
lp_k = 5

nmt.utils.logging_config(save_dir)
All Logs will be saved to gnmt_en_vi_u512/<ipython-input-2-4699ac3a1bfb>.log
Out[2]:
'gnmt_en_vi_u512'

Load and Preprocess Dataset

The following shows how to process the dataset and cache the processed dataset for future use. The processing steps include: 1) clipping the source and target sequences, 2) splitting the string input into a list of tokens, 3) mapping each string token to its integer index in the vocabulary, and 4) appending the end-of-sentence (EOS) token to the source sentence and adding both the beginning-of-sentence (BOS) and EOS tokens to the target sentence.
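
To make the index-mapping step concrete, here is a minimal sketch of steps 2) to 4) for a source sentence, using a hypothetical toy vocabulary rather than the IWSLT one:

import gluonnlp as nlp
import numpy as np

# Hypothetical toy vocabulary; the real src_vocab/tgt_vocab come from the dataset below.
toy_vocab = nlp.Vocab(nlp.data.count_tokens('hello world !'.split()))
src_ids = toy_vocab['hello world !'.split()]        # 2) split, 3) map tokens to indices
src_ids.append(toy_vocab[toy_vocab.eos_token])      # 4) append EOS to the source
print(np.array(src_ids, dtype=np.int32))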

In [3]:
def cache_dataset(dataset, prefix):
    """Cache the processed npy dataset  the dataset into a npz

    Parameters
    ----------
    dataset : gluon.data.SimpleDataset
    file_path : str
    """
    if not os.path.exists(nmt._constants.CACHE_PATH):
        os.makedirs(nmt._constants.CACHE_PATH)
    src_data = np.concatenate([e[0] for e in dataset])
    tgt_data = np.concatenate([e[1] for e in dataset])
    src_cumlen = np.cumsum([0]+[len(e[0]) for e in dataset])
    tgt_cumlen = np.cumsum([0]+[len(e[1]) for e in dataset])
    np.savez(os.path.join(nmt._constants.CACHE_PATH, prefix + '.npz'),
             src_data=src_data, tgt_data=tgt_data,
             src_cumlen=src_cumlen, tgt_cumlen=tgt_cumlen)


def load_cached_dataset(prefix):
    cached_file_path = os.path.join(nmt._constants.CACHE_PATH, prefix + '.npz')
    if os.path.exists(cached_file_path):
        print('Load cached data from {}'.format(cached_file_path))
        npz_data = np.load(cached_file_path)
        src_data, tgt_data, src_cumlen, tgt_cumlen = [npz_data[n] for n in
                ['src_data', 'tgt_data', 'src_cumlen', 'tgt_cumlen']]
        src_data = np.array([src_data[low:high] for low, high in zip(src_cumlen[:-1], src_cumlen[1:])])
        tgt_data = np.array([tgt_data[low:high] for low, high in zip(tgt_cumlen[:-1], tgt_cumlen[1:])])
        return gluon.data.ArrayDataset(np.array(src_data), np.array(tgt_data))
    else:
        return None


class TrainValDataTransform(object):
    """Transform the machine translation dataset.

    Clip the source and target sentences to the maximum length. For the source sentence,
    append EOS. For the target sentence, prepend BOS and append EOS.

    Parameters
    ----------
    src_vocab : Vocab
    tgt_vocab : Vocab
    src_max_len : int
    tgt_max_len : int
    """
    def __init__(self, src_vocab, tgt_vocab, src_max_len, tgt_max_len):
        self._src_vocab = src_vocab
        self._tgt_vocab = tgt_vocab
        self._src_max_len = src_max_len
        self._tgt_max_len = tgt_max_len

    def __call__(self, src, tgt):
        if self._src_max_len > 0:
            src_sentence = self._src_vocab[src.split()[:self._src_max_len]]
        else:
            src_sentence = self._src_vocab[src.split()]
        if self._tgt_max_len > 0:
            tgt_sentence = self._tgt_vocab[tgt.split()[:self._tgt_max_len]]
        else:
            tgt_sentence = self._tgt_vocab[tgt.split()]
        src_sentence.append(self._src_vocab[self._src_vocab.eos_token])
        tgt_sentence.insert(0, self._tgt_vocab[self._tgt_vocab.bos_token])
        tgt_sentence.append(self._tgt_vocab[self._tgt_vocab.eos_token])
        src_npy = np.array(src_sentence, dtype=np.int32)
        tgt_npy = np.array(tgt_sentence, dtype=np.int32)
        return src_npy, tgt_npy


def process_dataset(dataset, src_vocab, tgt_vocab, src_max_len=-1, tgt_max_len=-1):
    start = time.time()
    dataset_processed = dataset.transform(TrainValDataTransform(src_vocab, tgt_vocab,
                                                                src_max_len,
                                                                tgt_max_len), lazy=False)
    end = time.time()
    print('Processing time spent: {}'.format(end - start))
    return dataset_processed


def load_translation_data(dataset, src_lang='en', tgt_lang='vi'):
    """Load translation dataset

    Parameters
    ----------
    dataset : str
    src_lang : str, default 'en'
    tgt_lang : str, default 'vi'

    Returns
    -------
    data_train_processed : Dataset
        The preprocessed training sentence pairs
    data_val_processed : Dataset
        The preprocessed validation sentence pairs
    data_test_processed : Dataset
        The preprocessed test sentence pairs
    val_tgt_sentences : list
        The target sentences in the validation set
    test_tgt_sentences : list
        The target sentences in the test set
    src_vocab : Vocab
        Vocabulary of the source language
    tgt_vocab : Vocab
        Vocabulary of the target language
    """
    common_prefix = 'IWSLT2015_{}_{}_{}_{}'.format(src_lang, tgt_lang,
                                                   src_max_len, tgt_max_len)
    data_train = nlp.data.IWSLT2015('train', src_lang=src_lang, tgt_lang=tgt_lang)
    data_val = nlp.data.IWSLT2015('val', src_lang=src_lang, tgt_lang=tgt_lang)
    data_test = nlp.data.IWSLT2015('test', src_lang=src_lang, tgt_lang=tgt_lang)
    src_vocab, tgt_vocab = data_train.src_vocab, data_train.tgt_vocab
    data_train_processed = load_cached_dataset(common_prefix + '_train')
    if not data_train_processed:
        data_train_processed = process_dataset(data_train, src_vocab, tgt_vocab,
                                               src_max_len, tgt_max_len)
        cache_dataset(data_train_processed, common_prefix + '_train')
    data_val_processed = load_cached_dataset(common_prefix + '_val')
    if not data_val_processed:
        data_val_processed = process_dataset(data_val, src_vocab, tgt_vocab)
        cache_dataset(data_val_processed, common_prefix + '_val')
    data_test_processed = load_cached_dataset(common_prefix + '_test')
    if not data_test_processed:
        data_test_processed = process_dataset(data_test, src_vocab, tgt_vocab)
        cache_dataset(data_test_processed, common_prefix + '_test')
    fetch_tgt_sentence = lambda src, tgt: tgt.split()
    val_tgt_sentences = list(data_val.transform(fetch_tgt_sentence))
    test_tgt_sentences = list(data_test.transform(fetch_tgt_sentence))
    return data_train_processed, data_val_processed, data_test_processed, \
           val_tgt_sentences, test_tgt_sentences, src_vocab, tgt_vocab


def get_data_lengths(dataset):
    return list(dataset.transform(lambda src, tgt: (len(src), len(tgt))))


data_train, data_val, data_test, val_tgt_sentences, test_tgt_sentences, src_vocab, tgt_vocab\
    = load_translation_data(dataset=dataset, src_lang=src_lang, tgt_lang=tgt_lang)
data_train_lengths = get_data_lengths(data_train)
data_val_lengths = get_data_lengths(data_val)
data_test_lengths = get_data_lengths(data_test)

with io.open(os.path.join(save_dir, 'val_gt.txt'), 'w', encoding='utf-8') as of:
    for ele in val_tgt_sentences:
        of.write(' '.join(ele) + '\n')

with io.open(os.path.join(save_dir, 'test_gt.txt'), 'w', encoding='utf-8') as of:
    for ele in test_tgt_sentences:
        of.write(' '.join(ele) + '\n')


data_train = data_train.transform(lambda src, tgt: (src, tgt, len(src), len(tgt)), lazy=False)
data_val = gluon.data.SimpleDataset([(ele[0], ele[1], len(ele[0]), len(ele[1]), i)
                                     for i, ele in enumerate(data_val)])
data_test = gluon.data.SimpleDataset([(ele[0], ele[1], len(ele[0]), len(ele[1]), i)
                                      for i, ele in enumerate(data_test)])
Processing time spent: 15.472707033157349
Processing time spent: 0.28656888008117676
Processing time spent: 0.31331896781921387
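
Before batching, it can help to sanity-check one processed pair. A quick sketch (at this point each training example is a (src, tgt, src_length, tgt_length) tuple; this cell is not part of the original run):

src_ids, tgt_ids, src_len, tgt_len = data_train[0]
print(src_len, tgt_len)  # lengths include the appended EOS (source) and BOS/EOS (target)
print(' '.join(src_vocab.idx_to_token[int(i)] for i in src_ids[:5]))
print(' '.join(tgt_vocab.idx_to_token[int(i)] for i in tgt_ids[:5]))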

Create Sampler and DataLoader

Now, we have obtained data_train, data_val, and data_test. The next step is to construct the sampler and DataLoader. We first construct the batchify function, which pads and stacks sequences to form mini-batches.

In [4]:
train_batchify_fn = nlp.data.batchify.Tuple(nlp.data.batchify.Pad(),
                                            nlp.data.batchify.Pad(),
                                            nlp.data.batchify.Stack(dtype='float32'),
                                            nlp.data.batchify.Stack(dtype='float32'))
test_batchify_fn = nlp.data.batchify.Tuple(nlp.data.batchify.Pad(),
                                           nlp.data.batchify.Pad(),
                                           nlp.data.batchify.Stack(dtype='float32'),
                                           nlp.data.batchify.Stack(dtype='float32'),
                                           nlp.data.batchify.Stack())
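
To see what the batchify function produces, here is a minimal sketch on a toy batch of two variable-length samples; the Pad fields are padded to the longest sequence in the batch and the Stack fields become a single array:

toy_batch = [(np.array([1, 2, 3], dtype=np.int32), np.array([4, 5], dtype=np.int32), 3, 2),
             (np.array([6], dtype=np.int32), np.array([7, 8, 9], dtype=np.int32), 1, 3)]
src_pad, tgt_pad, src_len, tgt_len = train_batchify_fn(toy_batch)
print(src_pad.shape, tgt_pad.shape)  # (2, 3) and (2, 3); shorter rows are padded
print(src_len, tgt_len)              # the original valid lengths, stacked as float32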

We can then construct bucketing samplers, which generate batches by grouping sequences with similar lengths. Here, the bucketing scheme is empirically determined.

In [5]:
bucket_scheme = nlp.data.ExpWidthBucket(bucket_len_step=1.2)
train_batch_sampler = nlp.data.FixedBucketSampler(lengths=data_train_lengths,
                                                  batch_size=batch_size,
                                                  num_buckets=num_buckets,
                                                  shuffle=True,
                                                  bucket_scheme=bucket_scheme)
logging.info('Train Batch Sampler:\n{}'.format(train_batch_sampler.stats()))
val_batch_sampler = nlp.data.FixedBucketSampler(lengths=data_val_lengths,
                                                batch_size=test_batch_size,
                                                num_buckets=num_buckets,
                                                shuffle=False)
logging.info('Valid Batch Sampler:\n{}'.format(val_batch_sampler.stats()))
test_batch_sampler = nlp.data.FixedBucketSampler(lengths=data_test_lengths,
                                                 batch_size=test_batch_size,
                                                 num_buckets=num_buckets,
                                                 shuffle=False)
logging.info('Test Batch Sampler:\n{}'.format(test_batch_sampler.stats()))
2019-04-14 08:39:26,233 - root - Train Batch Sampler:
FixedBucketSampler:
  sample_num=133166, batch_num=1043
  key=[(9, 10), (16, 17), (26, 27), (37, 38), (51, 52)]
  cnt=[11414, 34897, 37760, 23480, 25615]
  batch_size=[128, 128, 128, 128, 128]
2019-04-14 08:39:26,247 - root - Valid Batch Sampler:
FixedBucketSampler:
  sample_num=1553, batch_num=52
  key=[(22, 28), (40, 52), (58, 76), (76, 100), (94, 124)]
  cnt=[1037, 432, 67, 10, 7]
  batch_size=[32, 32, 32, 32, 32]
2019-04-14 08:39:26,260 - root - Test Batch Sampler:
FixedBucketSampler:
  sample_num=1268, batch_num=42
  key=[(23, 29), (43, 53), (63, 77), (83, 101), (103, 125)]
  cnt=[770, 381, 84, 26, 7]
  batch_size=[32, 32, 32, 32, 32]
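
Each element yielded by a FixedBucketSampler is a list of dataset indices forming one bucketed mini-batch; a quick sketch:

batch_indices = next(iter(train_batch_sampler))
print(len(batch_indices), batch_indices[:5])  # batch size and a few sampled indices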

Given the samplers, we can create the DataLoaders, which are iterable.

In [6]:
train_data_loader = gluon.data.DataLoader(data_train,
                                          batch_sampler=train_batch_sampler,
                                          batchify_fn=train_batchify_fn,
                                          num_workers=4)
val_data_loader = gluon.data.DataLoader(data_val,
                                        batch_sampler=val_batch_sampler,
                                        batchify_fn=test_batchify_fn,
                                        num_workers=4)
test_data_loader = gluon.data.DataLoader(data_test,
                                         batch_sampler=test_batch_sampler,
                                         batchify_fn=test_batchify_fn,
                                         num_workers=4)
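
As a quick sketch, fetching a single mini-batch shows the padded shapes produced by the batchify function:

src_seq, tgt_seq, src_valid_length, tgt_valid_length = next(iter(train_data_loader))
print(src_seq.shape, tgt_seq.shape)  # (batch_size, max_src_len), (batch_size, max_tgt_len)
print(src_valid_length[:3], tgt_valid_length[:3])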

Build GNMT Model

After obtaining the DataLoaders, we can build the model. The GNMT encoder and decoder can be constructed by calling the get_gnmt_encoder_decoder function. Then, we feed the encoder and decoder into NMTModel to construct the GNMT model. model.hybridize allows computation to be done using the symbolic backend.

In [7]:
encoder, decoder = nmt.gnmt.get_gnmt_encoder_decoder(hidden_size=num_hidden,
                                                     dropout=dropout,
                                                     num_layers=num_layers,
                                                     num_bi_layers=num_bi_layers)
model = nlp.model.translation.NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder,
                                       decoder=decoder, embed_size=num_hidden, prefix='gnmt_')
model.initialize(init=mx.init.Uniform(0.1), ctx=ctx)
static_alloc = True
model.hybridize(static_alloc=static_alloc)
logging.info(model)

# Due to the paddings, we need to mask out the losses corresponding to padding tokens.
loss_function = nmt.loss.SoftmaxCEMaskedLoss()
loss_function.hybridize(static_alloc=static_alloc)
2019-04-14 08:39:45,410 - root - NMTModel(
  (encoder): GNMTEncoder(
    (dropout_layer): Dropout(p = 0.2, axes=())
    (rnn_cells): HybridSequential(
      (0): BidirectionalCell(forward=LSTMCell(None -> 2048), backward=LSTMCell(None -> 2048))
      (1): LSTMCell(None -> 2048)
    )
  )
  (decoder): GNMTDecoder(
    (attention_cell): DotProductAttentionCell(
      (_dropout_layer): Dropout(p = 0.0, axes=())
      (_proj_query): Dense(None -> 512, linear)
    )
    (dropout_layer): Dropout(p = 0.2, axes=())
    (rnn_cells): HybridSequential(
      (0): LSTMCell(None -> 2048)
      (1): LSTMCell(None -> 2048)
    )
  )
  (src_embed): HybridSequential(
    (0): Embedding(17191 -> 512, float32)
    (1): Dropout(p = 0.0, axes=())
  )
  (tgt_embed): HybridSequential(
    (0): Embedding(7709 -> 512, float32)
    (1): Dropout(p = 0.0, axes=())
  )
  (tgt_proj): Dense(None -> 7709, linear)
)
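
The masking behavior can be checked on toy inputs. A minimal sketch, assuming only that the loss takes (prediction, label, valid_length) as in the training code below:

# Toy check of the masked loss: positions beyond valid_length should not contribute.
toy_pred = mx.nd.random.uniform(shape=(2, 4, len(tgt_vocab)), ctx=ctx)  # (batch, seq, vocab)
toy_label = mx.nd.array([[5, 6, 7, 0], [5, 6, 0, 0]], ctx=ctx)
toy_valid = mx.nd.array([3, 2], ctx=ctx)  # only the first 3 / 2 steps are real tokens
print(loss_function(toy_pred, toy_label, toy_valid))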

We also build the beam search translator.

In [8]:
translator = nmt.translation.BeamSearchTranslator(model=model, beam_size=beam_size,
                                                  scorer=nlp.model.BeamSearchScorer(alpha=lp_alpha,
                                                                                    K=lp_k),
                                                  max_length=tgt_max_len + 100)
logging.info('Use beam_size={}, alpha={}, K={}'.format(beam_size, lp_alpha, lp_k))
2019-04-14 08:39:45,683 - root - Use beam_size=10, alpha=1.0, K=5
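
As a usage sketch, the translator can be applied to a single tokenized sentence (the sentence below is hypothetical, and at this point the model is still untrained):

toy_src = 'hello world .'
toy_ids = src_vocab[toy_src.split()] + [src_vocab[src_vocab.eos_token]]
toy_seq = mx.nd.array([toy_ids], ctx=ctx)
toy_len = mx.nd.array([len(toy_ids)], ctx=ctx)
samples, scores, sample_valid_length = translator.translate(src_seq=toy_seq,
                                                            src_valid_length=toy_len)
best = samples[0, 0, :].asnumpy().astype('int32')
length = int(sample_valid_length[0, 0].asscalar())
print(' '.join(tgt_vocab.idx_to_token[i] for i in best[1:length - 1]))  # strip BOS/EOS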

We define the evaluation function as follows. The evaluate function uses the beam search translator to generate outputs for the validation and test datasets.

In [9]:
def evaluate(data_loader):
    """Evaluate given the data loader

    Parameters
    ----------
    data_loader : gluon.data.DataLoader

    Returns
    -------
    avg_loss : float
        Average loss
    real_translation_out : list of list of str
        The translation output
    """
    translation_out = []
    all_inst_ids = []
    avg_loss_denom = 0
    avg_loss = 0.0
    for _, (src_seq, tgt_seq, src_valid_length, tgt_valid_length, inst_ids) \
            in enumerate(data_loader):
        src_seq = src_seq.as_in_context(ctx)
        tgt_seq = tgt_seq.as_in_context(ctx)
        src_valid_length = src_valid_length.as_in_context(ctx)
        tgt_valid_length = tgt_valid_length.as_in_context(ctx)
        # Calculating Loss
        out, _ = model(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1)
        loss = loss_function(out, tgt_seq[:, 1:], tgt_valid_length - 1).mean().asscalar()
        all_inst_ids.extend(inst_ids.asnumpy().astype(np.int32).tolist())
        avg_loss += loss * (tgt_seq.shape[1] - 1)
        avg_loss_denom += (tgt_seq.shape[1] - 1)
        # Translate
        samples, _, sample_valid_length =\
            translator.translate(src_seq=src_seq, src_valid_length=src_valid_length)
        max_score_sample = samples[:, 0, :].asnumpy()
        sample_valid_length = sample_valid_length[:, 0].asnumpy()
        for i in range(max_score_sample.shape[0]):
            translation_out.append(
                [tgt_vocab.idx_to_token[ele] for ele in
                 max_score_sample[i][1:(sample_valid_length[i] - 1)]])
    avg_loss = avg_loss / avg_loss_denom
    real_translation_out = [None for _ in range(len(all_inst_ids))]
    for ind, sentence in zip(all_inst_ids, translation_out):
        real_translation_out[ind] = sentence
    return avg_loss, real_translation_out


def write_sentences(sentences, file_path):
    with io.open(file_path, 'w', encoding='utf-8') as of:
        for sent in sentences:
            of.write(' '.join(sent) + '\n')
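
A quick sketch of the compute_bleu call used in the training loop below, on toy tokenized sentences (an exact match yields BLEU = 1):

toy_refs = [['this', 'is', 'a', 'test', '.']]  # one reference sentence per example
toy_hyps = [['this', 'is', 'a', 'test', '.']]  # hypothesis translations
toy_bleu, _, _, _, _ = nmt.bleu.compute_bleu([toy_refs], toy_hyps)
print('Toy BLEU = {:.2f}'.format(toy_bleu * 100))  # 100.00 for an exact match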

Training Epochs

Before entering the training stage, we need to create a trainer for updating the parameters. In the following example, we create a trainer that uses the Adam optimizer.

In [10]:
trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate': lr})

We can then write the training loop. During training, we evaluate on the validation and test datasets every epoch, and record the parameters that give the highest BLEU score on the validation dataset. Before performing the forward and backward passes, we first use the as_in_context function to copy the mini-batch to the GPU. The statement with mx.autograd.record() tells the Gluon backend to compute the gradients for the part inside the block.

In [11]:
best_valid_bleu = 0.0
for epoch_id in range(epochs):
    log_avg_loss = 0
    log_avg_gnorm = 0
    log_wc = 0
    log_start_time = time.time()
    for batch_id, (src_seq, tgt_seq, src_valid_length, tgt_valid_length)\
            in enumerate(train_data_loader):
        src_seq = src_seq.as_in_context(ctx)
        tgt_seq = tgt_seq.as_in_context(ctx)
        src_valid_length = src_valid_length.as_in_context(ctx)
        tgt_valid_length = tgt_valid_length.as_in_context(ctx)
        with mx.autograd.record():
            out, _ = model(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1)
            loss = loss_function(out, tgt_seq[:, 1:], tgt_valid_length - 1).mean()
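            # Rescale: the masked loss averages over the padded target length, so
            # multiplying by padded_len / mean(valid_len) makes it (approximately)
            # a mean over valid target tokens instead.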
            loss = loss * (tgt_seq.shape[1] - 1) / (tgt_valid_length - 1).mean()
            loss.backward()
        grads = [p.grad(ctx) for p in model.collect_params().values()]
        gnorm = gluon.utils.clip_global_norm(grads, clip)
        trainer.step(1)
        src_wc = src_valid_length.sum().asscalar()
        tgt_wc = (tgt_valid_length - 1).sum().asscalar()
        step_loss = loss.asscalar()
        log_avg_loss += step_loss
        log_avg_gnorm += gnorm
        log_wc += src_wc + tgt_wc
        if (batch_id + 1) % log_interval == 0:
            wps = log_wc / (time.time() - log_start_time)
            logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, gnorm={:.4f}, '
                         'throughput={:.2f}K wps, wc={:.2f}K'
                         .format(epoch_id, batch_id + 1, len(train_data_loader),
                                 log_avg_loss / log_interval,
                                 np.exp(log_avg_loss / log_interval),
                                 log_avg_gnorm / log_interval,
                                 wps / 1000, log_wc / 1000))
            log_start_time = time.time()
            log_avg_loss = 0
            log_avg_gnorm = 0
            log_wc = 0
    valid_loss, valid_translation_out = evaluate(val_data_loader)
    valid_bleu_score, _, _, _, _ = nmt.bleu.compute_bleu([val_tgt_sentences], valid_translation_out)
    logging.info('[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'
                 .format(epoch_id, valid_loss, np.exp(valid_loss), valid_bleu_score * 100))
    test_loss, test_translation_out = evaluate(test_data_loader)
    test_bleu_score, _, _, _, _ = nmt.bleu.compute_bleu([test_tgt_sentences], test_translation_out)
    logging.info('[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'
                 .format(epoch_id, test_loss, np.exp(test_loss), test_bleu_score * 100))
    write_sentences(valid_translation_out,
                    os.path.join(save_dir, 'epoch{:d}_valid_out.txt').format(epoch_id))
    write_sentences(test_translation_out,
                    os.path.join(save_dir, 'epoch{:d}_test_out.txt').format(epoch_id))
    if valid_bleu_score > best_valid_bleu:
        best_valid_bleu = valid_bleu_score
        save_path = os.path.join(save_dir, 'valid_best.params')
        logging.info('Save best parameters to {}'.format(save_path))
        model.save_parameters(save_path)
    if epoch_id + 1 >= (epochs * 2) // 3:
        new_lr = trainer.learning_rate * lr_update_factor
        logging.info('Learning rate change to {}'.format(new_lr))
        trainer.set_learning_rate(new_lr)
2019-04-14 08:39:59,061 - root - [Epoch 0 Batch 10/1043] loss=7.7375, ppl=2292.6586, gnorm=1.4907, throughput=4.13K wps, wc=54.27K
2019-04-14 08:40:06,626 - root - [Epoch 0 Batch 20/1043] loss=6.3590, ppl=577.6408, gnorm=1.5744, throughput=6.64K wps, wc=50.20K
2019-04-14 08:40:16,050 - root - [Epoch 0 Batch 30/1043] loss=6.3708, ppl=584.5344, gnorm=0.8043, throughput=7.20K wps, wc=67.78K
2019-04-14 08:40:24,928 - root - [Epoch 0 Batch 40/1043] loss=6.1791, ppl=482.5550, gnorm=0.6213, throughput=7.12K wps, wc=63.19K
2019-04-14 08:40:32,580 - root - [Epoch 0 Batch 50/1043] loss=6.1872, ppl=486.4970, gnorm=0.3987, throughput=8.10K wps, wc=61.93K
2019-04-14 08:40:38,539 - root - [Epoch 0 Batch 60/1043] loss=6.1053, ppl=448.2160, gnorm=0.6782, throughput=9.94K wps, wc=59.19K
2019-04-14 08:40:46,593 - root - [Epoch 0 Batch 70/1043] loss=6.1555, ppl=471.2990, gnorm=0.4628, throughput=9.07K wps, wc=72.99K
2019-04-14 08:40:53,537 - root - [Epoch 0 Batch 80/1043] loss=6.0697, ppl=432.5624, gnorm=0.4173, throughput=9.30K wps, wc=64.58K
2019-04-14 08:40:59,599 - root - [Epoch 0 Batch 90/1043] loss=5.9385, ppl=379.3764, gnorm=0.3605, throughput=8.75K wps, wc=53.02K
2019-04-14 08:41:06,958 - root - [Epoch 0 Batch 100/1043] loss=5.8757, ppl=356.2836, gnorm=0.3960, throughput=8.08K wps, wc=59.42K
2019-04-14 08:41:14,255 - root - [Epoch 0 Batch 110/1043] loss=5.8713, ppl=354.7196, gnorm=0.3559, throughput=8.98K wps, wc=65.50K
2019-04-14 08:41:19,899 - root - [Epoch 0 Batch 120/1043] loss=5.8693, ppl=353.9884, gnorm=0.3342, throughput=10.36K wps, wc=58.43K
2019-04-14 08:41:26,501 - root - [Epoch 0 Batch 130/1043] loss=5.9360, ppl=378.4299, gnorm=0.3630, throughput=9.00K wps, wc=59.39K
2019-04-14 08:41:32,975 - root - [Epoch 0 Batch 140/1043] loss=5.8838, ppl=359.1623, gnorm=0.2900, throughput=9.45K wps, wc=61.18K
2019-04-14 08:41:38,776 - root - [Epoch 0 Batch 150/1043] loss=5.8180, ppl=336.2913, gnorm=0.2885, throughput=9.72K wps, wc=56.34K
2019-04-14 08:41:45,483 - root - [Epoch 0 Batch 160/1043] loss=5.7480, ppl=313.5724, gnorm=0.3821, throughput=8.64K wps, wc=57.93K
2019-04-14 08:41:50,552 - root - [Epoch 0 Batch 170/1043] loss=5.7697, ppl=320.4299, gnorm=0.2978, throughput=12.72K wps, wc=64.42K
2019-04-14 08:41:53,983 - root - [Epoch 0 Batch 180/1043] loss=5.4645, ppl=236.1659, gnorm=0.3361, throughput=12.92K wps, wc=44.31K
2019-04-14 08:41:58,815 - root - [Epoch 0 Batch 190/1043] loss=5.6188, ppl=275.5575, gnorm=0.3574, throughput=12.93K wps, wc=62.45K
2019-04-14 08:42:03,419 - root - [Epoch 0 Batch 200/1043] loss=5.5120, ppl=247.6430, gnorm=0.3602, throughput=11.79K wps, wc=54.24K
2019-04-14 08:42:07,823 - root - [Epoch 0 Batch 210/1043] loss=5.3315, ppl=206.7393, gnorm=0.4307, throughput=11.97K wps, wc=52.68K
2019-04-14 08:42:11,373 - root - [Epoch 0 Batch 220/1043] loss=5.3136, ppl=203.0871, gnorm=0.3820, throughput=14.23K wps, wc=50.47K
2019-04-14 08:42:15,877 - root - [Epoch 0 Batch 230/1043] loss=5.3114, ppl=202.6384, gnorm=0.4176, throughput=13.70K wps, wc=61.65K
2019-04-14 08:42:18,859 - root - [Epoch 0 Batch 240/1043] loss=5.3077, ppl=201.8900, gnorm=0.3525, throughput=19.58K wps, wc=58.32K
2019-04-14 08:42:23,266 - root - [Epoch 0 Batch 250/1043] loss=5.4239, ppl=226.7514, gnorm=0.2877, throughput=16.08K wps, wc=70.84K
2019-04-14 08:42:27,200 - root - [Epoch 0 Batch 260/1043] loss=5.2793, ppl=196.2366, gnorm=0.3558, throughput=15.32K wps, wc=60.22K
2019-04-14 08:42:31,911 - root - [Epoch 0 Batch 270/1043] loss=5.3794, ppl=216.8982, gnorm=0.2814, throughput=15.50K wps, wc=72.96K
2019-04-14 08:42:35,196 - root - [Epoch 0 Batch 280/1043] loss=5.2551, ppl=191.5379, gnorm=0.2626, throughput=18.53K wps, wc=60.80K
2019-04-14 08:42:37,818 - root - [Epoch 0 Batch 290/1043] loss=4.9794, ppl=145.3931, gnorm=0.3467, throughput=17.48K wps, wc=45.79K
2019-04-14 08:42:40,906 - root - [Epoch 0 Batch 300/1043] loss=5.0895, ppl=162.3008, gnorm=0.3472, throughput=19.24K wps, wc=59.05K
2019-04-14 08:42:44,789 - root - [Epoch 0 Batch 310/1043] loss=5.1061, ppl=165.0174, gnorm=0.2920, throughput=15.86K wps, wc=61.58K
2019-04-14 08:42:48,284 - root - [Epoch 0 Batch 320/1043] loss=4.9833, ppl=145.9550, gnorm=0.3093, throughput=15.22K wps, wc=53.10K
2019-04-14 08:42:52,096 - root - [Epoch 0 Batch 330/1043] loss=5.0395, ppl=154.3902, gnorm=0.3038, throughput=16.11K wps, wc=61.37K
2019-04-14 08:42:55,745 - root - [Epoch 0 Batch 340/1043] loss=5.0647, ppl=158.3260, gnorm=0.2648, throughput=15.60K wps, wc=56.88K
2019-04-14 08:42:59,604 - root - [Epoch 0 Batch 350/1043] loss=4.9178, ppl=136.7084, gnorm=0.3157, throughput=14.23K wps, wc=54.86K
2019-04-14 08:43:03,594 - root - [Epoch 0 Batch 360/1043] loss=5.0448, ppl=155.2102, gnorm=0.2936, throughput=16.19K wps, wc=64.55K
2019-04-14 08:43:07,564 - root - [Epoch 0 Batch 370/1043] loss=4.9053, ppl=135.0092, gnorm=0.3128, throughput=16.89K wps, wc=66.97K
2019-04-14 08:43:11,111 - root - [Epoch 0 Batch 380/1043] loss=4.8187, ppl=123.8092, gnorm=0.3114, throughput=14.90K wps, wc=52.79K
2019-04-14 08:43:14,499 - root - [Epoch 0 Batch 390/1043] loss=4.7931, ppl=120.6737, gnorm=0.3433, throughput=15.05K wps, wc=50.94K
2019-04-14 08:43:18,219 - root - [Epoch 0 Batch 400/1043] loss=4.6418, ppl=103.7261, gnorm=0.3465, throughput=12.98K wps, wc=48.22K
2019-04-14 08:43:23,286 - root - [Epoch 0 Batch 410/1043] loss=4.8132, ppl=123.1204, gnorm=0.3010, throughput=9.54K wps, wc=48.27K
2019-04-14 08:43:28,576 - root - [Epoch 0 Batch 420/1043] loss=4.8389, ppl=126.3339, gnorm=0.3056, throughput=10.62K wps, wc=56.14K
2019-04-14 08:43:35,005 - root - [Epoch 0 Batch 430/1043] loss=4.8893, ppl=132.8639, gnorm=0.3135, throughput=10.80K wps, wc=69.33K
2019-04-14 08:43:40,311 - root - [Epoch 0 Batch 440/1043] loss=4.7815, ppl=119.2835, gnorm=0.2937, throughput=12.67K wps, wc=67.08K
2019-04-14 08:43:45,003 - root - [Epoch 0 Batch 450/1043] loss=4.6943, ppl=109.3172, gnorm=0.3366, throughput=11.44K wps, wc=53.68K
2019-04-14 08:43:49,555 - root - [Epoch 0 Batch 460/1043] loss=4.4797, ppl=88.2107, gnorm=0.3680, throughput=11.08K wps, wc=50.38K
2019-04-14 08:43:54,337 - root - [Epoch 0 Batch 470/1043] loss=4.7576, ppl=116.4623, gnorm=0.3262, throughput=12.70K wps, wc=60.70K
2019-04-14 08:43:58,731 - root - [Epoch 0 Batch 480/1043] loss=4.3964, ppl=81.1588, gnorm=0.3478, throughput=12.31K wps, wc=54.04K
2019-04-14 08:44:02,540 - root - [Epoch 0 Batch 490/1043] loss=4.5168, ppl=91.5466, gnorm=0.4307, throughput=12.17K wps, wc=46.32K
2019-04-14 08:44:06,741 - root - [Epoch 0 Batch 500/1043] loss=4.5597, ppl=95.5510, gnorm=0.3093, throughput=11.57K wps, wc=48.55K
2019-04-14 08:44:10,618 - root - [Epoch 0 Batch 510/1043] loss=4.3226, ppl=75.3857, gnorm=0.3331, throughput=10.75K wps, wc=41.62K
2019-04-14 08:44:15,654 - root - [Epoch 0 Batch 520/1043] loss=4.2216, ppl=68.1392, gnorm=0.3600, throughput=7.94K wps, wc=39.98K
2019-04-14 08:44:21,093 - root - [Epoch 0 Batch 530/1043] loss=4.6480, ppl=104.3750, gnorm=0.2983, throughput=10.78K wps, wc=58.58K
2019-04-14 08:44:25,247 - root - [Epoch 0 Batch 540/1043] loss=4.5112, ppl=91.0268, gnorm=0.3089, throughput=12.22K wps, wc=50.72K
2019-04-14 08:44:29,900 - root - [Epoch 0 Batch 550/1043] loss=4.5359, ppl=93.3104, gnorm=0.3199, throughput=13.54K wps, wc=62.95K
2019-04-14 08:44:33,581 - root - [Epoch 0 Batch 560/1043] loss=4.3236, ppl=75.4583, gnorm=0.3499, throughput=12.54K wps, wc=46.13K
2019-04-14 08:44:38,997 - root - [Epoch 0 Batch 570/1043] loss=4.3446, ppl=77.0582, gnorm=0.3143, throughput=11.29K wps, wc=61.12K
2019-04-14 08:44:44,196 - root - [Epoch 0 Batch 580/1043] loss=4.3568, ppl=78.0038, gnorm=0.3161, throughput=10.67K wps, wc=55.43K
2019-04-14 08:44:49,832 - root - [Epoch 0 Batch 590/1043] loss=4.5871, ppl=98.2086, gnorm=0.2637, throughput=13.12K wps, wc=73.93K
2019-04-14 08:44:53,759 - root - [Epoch 0 Batch 600/1043] loss=4.4915, ppl=89.2552, gnorm=0.2708, throughput=14.22K wps, wc=55.80K
2019-04-14 08:44:58,433 - root - [Epoch 0 Batch 610/1043] loss=4.3116, ppl=74.5632, gnorm=0.3215, throughput=11.19K wps, wc=52.28K
2019-04-14 08:45:04,528 - root - [Epoch 0 Batch 620/1043] loss=4.5561, ppl=95.2079, gnorm=0.2642, throughput=11.88K wps, wc=72.39K
2019-04-14 08:45:07,199 - root - [Epoch 0 Batch 630/1043] loss=4.0478, ppl=57.2698, gnorm=0.3483, throughput=12.91K wps, wc=34.44K
2019-04-14 08:45:11,705 - root - [Epoch 0 Batch 640/1043] loss=4.3370, ppl=76.4775, gnorm=0.3257, throughput=12.72K wps, wc=57.32K
2019-04-14 08:45:17,065 - root - [Epoch 0 Batch 650/1043] loss=4.4735, ppl=87.6636, gnorm=0.2794, throughput=12.40K wps, wc=66.42K
2019-04-14 08:45:21,140 - root - [Epoch 0 Batch 660/1043] loss=4.1894, ppl=65.9808, gnorm=0.3359, throughput=10.91K wps, wc=44.42K
2019-04-14 08:45:27,378 - root - [Epoch 0 Batch 670/1043] loss=4.5967, ppl=99.1571, gnorm=0.2508, throughput=12.46K wps, wc=77.68K
2019-04-14 08:45:31,644 - root - [Epoch 0 Batch 680/1043] loss=4.4215, ppl=83.2244, gnorm=0.2815, throughput=13.85K wps, wc=59.02K
2019-04-14 08:45:35,671 - root - [Epoch 0 Batch 690/1043] loss=4.2071, ppl=67.1633, gnorm=0.3269, throughput=12.56K wps, wc=50.54K
2019-04-14 08:45:41,317 - root - [Epoch 0 Batch 700/1043] loss=4.2906, ppl=73.0101, gnorm=0.3036, throughput=9.30K wps, wc=52.45K
2019-04-14 08:45:46,884 - root - [Epoch 0 Batch 710/1043] loss=4.1826, ppl=65.5381, gnorm=0.3347, throughput=7.43K wps, wc=41.32K
2019-04-14 08:45:53,375 - root - [Epoch 0 Batch 720/1043] loss=4.2664, ppl=71.2630, gnorm=0.2984, throughput=7.72K wps, wc=50.09K
2019-04-14 08:45:59,465 - root - [Epoch 0 Batch 730/1043] loss=4.1932, ppl=66.2351, gnorm=0.3216, throughput=8.55K wps, wc=52.02K
2019-04-14 08:46:04,228 - root - [Epoch 0 Batch 740/1043] loss=4.3315, ppl=76.0582, gnorm=0.2732, throughput=12.57K wps, wc=59.86K
2019-04-14 08:46:09,161 - root - [Epoch 0 Batch 750/1043] loss=4.1921, ppl=66.1583, gnorm=0.2890, throughput=10.82K wps, wc=53.35K
2019-04-14 08:46:14,764 - root - [Epoch 0 Batch 760/1043] loss=4.2894, ppl=72.9236, gnorm=0.2832, throughput=12.52K wps, wc=70.09K
2019-04-14 08:46:18,593 - root - [Epoch 0 Batch 770/1043] loss=4.1139, ppl=61.1873, gnorm=0.3146, throughput=11.29K wps, wc=43.21K
2019-04-14 08:46:24,203 - root - [Epoch 0 Batch 780/1043] loss=4.3072, ppl=74.2328, gnorm=0.2694, throughput=13.54K wps, wc=75.93K
2019-04-14 08:46:27,903 - root - [Epoch 0 Batch 790/1043] loss=4.1360, ppl=62.5520, gnorm=0.3072, throughput=12.66K wps, wc=46.81K
2019-04-14 08:46:33,675 - root - [Epoch 0 Batch 800/1043] loss=4.2081, ppl=67.2282, gnorm=0.3197, throughput=10.18K wps, wc=58.72K
2019-04-14 08:46:39,898 - root - [Epoch 0 Batch 810/1043] loss=4.0845, ppl=59.4102, gnorm=0.2990, throughput=9.23K wps, wc=57.38K
2019-04-14 08:46:44,532 - root - [Epoch 0 Batch 820/1043] loss=3.9642, ppl=52.6782, gnorm=0.3157, throughput=12.64K wps, wc=58.52K
2019-04-14 08:46:48,792 - root - [Epoch 0 Batch 830/1043] loss=4.1177, ppl=61.4179, gnorm=0.3435, throughput=13.46K wps, wc=57.24K
2019-04-14 08:46:53,451 - root - [Epoch 0 Batch 840/1043] loss=4.0863, ppl=59.5169, gnorm=0.3151, throughput=11.30K wps, wc=52.57K
2019-04-14 08:46:59,497 - root - [Epoch 0 Batch 850/1043] loss=4.1330, ppl=62.3636, gnorm=0.2906, throughput=10.61K wps, wc=64.14K
2019-04-14 08:47:05,246 - root - [Epoch 0 Batch 860/1043] loss=4.1194, ppl=61.5244, gnorm=0.2869, throughput=9.51K wps, wc=54.64K
2019-04-14 08:47:11,183 - root - [Epoch 0 Batch 870/1043] loss=4.1577, ppl=63.9274, gnorm=0.3132, throughput=11.10K wps, wc=65.81K
2019-04-14 08:47:16,294 - root - [Epoch 0 Batch 880/1043] loss=4.0882, ppl=59.6344, gnorm=0.3065, throughput=9.84K wps, wc=50.27K
2019-04-14 08:47:22,118 - root - [Epoch 0 Batch 890/1043] loss=4.1563, ppl=63.8375, gnorm=0.2833, throughput=9.64K wps, wc=56.11K
2019-04-14 08:47:27,747 - root - [Epoch 0 Batch 900/1043] loss=4.1765, ppl=65.1365, gnorm=0.2890, throughput=10.83K wps, wc=60.91K
2019-04-14 08:47:32,237 - root - [Epoch 0 Batch 910/1043] loss=3.9814, ppl=53.5921, gnorm=0.3184, throughput=11.51K wps, wc=51.65K
2019-04-14 08:47:37,793 - root - [Epoch 0 Batch 920/1043] loss=4.1608, ppl=64.1245, gnorm=0.2799, throughput=10.90K wps, wc=60.52K
2019-04-14 08:47:42,728 - root - [Epoch 0 Batch 930/1043] loss=3.9635, ppl=52.6412, gnorm=0.2959, throughput=8.82K wps, wc=43.51K
2019-04-14 08:47:47,311 - root - [Epoch 0 Batch 940/1043] loss=3.9273, ppl=50.7686, gnorm=0.3379, throughput=10.85K wps, wc=49.71K
2019-04-14 08:47:53,635 - root - [Epoch 0 Batch 950/1043] loss=4.1982, ppl=66.5675, gnorm=0.2712, throughput=11.22K wps, wc=70.92K
2019-04-14 08:47:59,691 - root - [Epoch 0 Batch 960/1043] loss=4.1457, ppl=63.1598, gnorm=0.2726, throughput=12.25K wps, wc=74.06K
2019-04-14 08:48:03,300 - root - [Epoch 0 Batch 970/1043] loss=3.8590, ppl=47.4186, gnorm=0.3494, throughput=12.21K wps, wc=44.03K
2019-04-14 08:48:08,914 - root - [Epoch 0 Batch 980/1043] loss=4.1647, ppl=64.3718, gnorm=0.2674, throughput=12.96K wps, wc=72.73K
2019-04-14 08:48:13,519 - root - [Epoch 0 Batch 990/1043] loss=4.0820, ppl=59.2643, gnorm=0.2947, throughput=14.21K wps, wc=65.39K
2019-04-14 08:48:18,502 - root - [Epoch 0 Batch 1000/1043] loss=3.9781, ppl=53.4131, gnorm=0.3141, throughput=11.17K wps, wc=55.65K
2019-04-14 08:48:23,691 - root - [Epoch 0 Batch 1010/1043] loss=4.0658, ppl=58.3116, gnorm=0.2911, throughput=12.65K wps, wc=65.61K
2019-04-14 08:48:27,343 - root - [Epoch 0 Batch 1020/1043] loss=3.8071, ppl=45.0195, gnorm=0.3343, throughput=12.53K wps, wc=45.72K
2019-04-14 08:48:31,677 - root - [Epoch 0 Batch 1030/1043] loss=3.9698, ppl=52.9729, gnorm=0.3078, throughput=12.92K wps, wc=55.97K
2019-04-14 08:48:35,959 - root - [Epoch 0 Batch 1040/1043] loss=3.9267, ppl=50.7392, gnorm=0.3543, throughput=12.45K wps, wc=53.27K
2019-04-14 08:49:21,059 - root - [Epoch 0] valid Loss=2.8464, valid ppl=17.2251, valid bleu=3.32
2019-04-14 08:50:00,516 - root - [Epoch 0] test Loss=2.9874, test ppl=19.8333, test bleu=3.14
2019-04-14 08:50:00,539 - root - Save best parameters to gnmt_en_vi_u512/valid_best.params
2019-04-14 08:50:00,904 - root - Learning rate change to 0.0005
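
After training, the checkpoint with the best validation BLEU can be restored for further evaluation. A minimal sketch, not part of the original run:

model.load_parameters(os.path.join(save_dir, 'valid_best.params'), ctx=ctx)
test_loss, test_translation_out = evaluate(test_data_loader)
test_bleu_score, _, _, _, _ = nmt.bleu.compute_bleu([test_tgt_sentences], test_translation_out)
print('Best-checkpoint test BLEU = {:.2f}'.format(test_bleu_score * 100))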

Summary

In this notebook, we have shown how to train a GNMT model on the IWSLT 2015 English-Vietnamese dataset using the GluonNLP toolkit. The complete training script can be found here. The command to reproduce the result can be found on the machine translation page.