
Train your own LSTM-based Language Model

Now let’s go through, step by step, how to train your own language model using GluonNLP.

Preparation

We’ll start by taking care of our basic dependencies and setting up our environment.

First, we import the modules required for GluonNLP and the language model (LM).

[1]:
import warnings
warnings.filterwarnings('ignore')

import glob
import time
import math

import mxnet as mx
from mxnet import gluon, autograd
from mxnet.gluon.utils import download

import gluonnlp as nlp
nlp.utils.check_version('0.7.0')

Then we set up the environment for GluonNLP.

Please note that in the following code you should set num_gpus according to how many NVIDIA GPUs are available on the target machine; setting it to 0 falls back to the CPU.

[2]:
num_gpus = 1
context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus else [mx.cpu()]
log_interval = 200

Next we set up the hyperparameters for the LM we are using.

Note that BPTT stands for “backpropagation through time,” and LR stands for learning rate. A link to more information on truncated BPTT can be found here.

[3]:
batch_size = 20 * len(context)
lr = 20
epochs = 3
bptt = 35
grad_clip = 0.25

Loading the dataset

Now we load the dataset, extract the vocabulary, numericalize the text, and batchify it so that we can perform truncated BPTT.

[4]:
dataset_name = 'wikitext-2'

# Load the dataset
train_dataset, val_dataset, test_dataset = [
    nlp.data.WikiText2(
        segment=segment, bos=None, eos='<eos>', skip_empty=False)
    for segment in ['train', 'val', 'test']
]

# Extract the vocabulary and numericalize with "Counter"
vocab = nlp.Vocab(
    nlp.data.Counter(train_dataset), padding_token=None, bos_token=None)

# Batchify for BPTT
bptt_batchify = nlp.data.batchify.CorpusBPTTBatchify(
    vocab, bptt, batch_size, last_batch='discard')
train_data, val_data, test_data = [
    bptt_batchify(x) for x in [train_dataset, val_dataset, test_dataset]
]
Downloading /root/.mxnet/datasets/wikitext-2/wikitext-2-v1.zip from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/wikitext-2/wikitext-2-v1.zip...
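
To see what the batchify step produced, we can peek at a single mini-batch. This is a minimal sketch using the train_data and vocab objects defined above; each mini-batch is a pair of (bptt, batch_size) arrays (with the single-GPU settings above, that is (35, 20)), and the target is simply the data shifted forward by one token.

# Inspect one BPTT mini-batch: both arrays have shape (bptt, batch_size)
sample_data, sample_target = next(iter(train_data))
print(sample_data.shape, sample_target.shape)

# Decode the first few time steps of the first sequence in the batch;
# the target is the same text shifted by one position
print([vocab.idx_to_token[int(i)] for i in sample_data[:5, 0].asnumpy()])
print([vocab.idx_to_token[int(i)] for i in sample_target[:5, 0].asnumpy()])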

Then we load the pre-defined language model architecture like so:

[5]:
model_name = 'standard_lstm_lm_200'
model, vocab = nlp.model.get_model(model_name, vocab=vocab, dataset_name=None)
print(model)
print(vocab)

# Initialize the model
model.initialize(mx.init.Xavier(), ctx=context)

# Initialize the trainer and optimizer and specify some hyperparameters
trainer = gluon.Trainer(model.collect_params(), 'sgd', {
    'learning_rate': lr,
    'momentum': 0,
    'wd': 0
})

# Specify the loss function, in this case, cross-entropy with softmax.
loss = gluon.loss.SoftmaxCrossEntropyLoss()
StandardRNN(
  (embedding): HybridSequential(
    (0): Embedding(33278 -> 200, float32)
    (1): Dropout(p = 0.2, axes=())
  )
  (encoder): LSTM(200 -> 200, TNC, num_layers=2, dropout=0.2)
  (decoder): HybridSequential(
    (0): Dense(200 -> 33278, linear)
  )
)
Vocab(size=33278, unk="<unk>", reserved="['<eos>']")
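
Before training, it can be useful to sanity-check the freshly initialized model with a single forward pass. The following is a minimal sketch (it assumes the first device in context and a batch drawn as in the inspection above); the output should have shape (bptt, batch_size, len(vocab)).

# Run one forward pass on a single mini-batch to verify shapes
sample_data, _ = next(iter(train_data))
sample_data = sample_data.as_in_context(context[0])

# One hidden state per sequence in the batch, initialized to zeros
hidden = model.begin_state(batch_size=batch_size, func=mx.nd.zeros, ctx=context[0])
output, hidden = model(sample_data, hidden)

print(output.shape)  # expected: (35, 20, 33278) with the settings above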

Training the LM

Now that everything is ready, we can start training the model.

We first define a helper function that detaches the hidden states from the computation graph, so that gradients from one mini-batch do not propagate back through earlier BPTT windows.

[6]:
def detach(hidden):
    if isinstance(hidden, (tuple, list)):
        hidden = [detach(i) for i in hidden]
    else:
        hidden = hidden.detach()
    return hidden

And then a helper evaluation function.

[7]:
# Note that ctx is short for context
def evaluate(model, data_source, batch_size, ctx):
    total_L = 0.0
    ntotal = 0
    hidden = model.begin_state(
        batch_size=batch_size, func=mx.nd.zeros, ctx=ctx)
    for i, (data, target) in enumerate(data_source):
        data = data.as_in_context(ctx)
        target = target.as_in_context(ctx)
        output, hidden = model(data, hidden)
        hidden = detach(hidden)
        # Merge the time and batch dimensions before computing the per-token loss
        L = loss(output.reshape(-3, -1), target.reshape(-1))
        total_L += mx.nd.sum(L).asscalar()
        ntotal += L.size
    return total_L / ntotal
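
Since evaluate returns the average per-token cross-entropy, running it on the untrained model gives a useful baseline: a model that is essentially guessing uniformly should have a perplexity on the order of the vocabulary size (roughly 33k here). A minimal sketch, assuming it is run before training:

# Perplexity of the freshly initialized model; expect a value close to len(vocab)
initial_val_L = evaluate(model, val_data, batch_size, context[0])
print('Untrained model: valid loss %.2f, valid ppl %.2f'
      % (initial_val_L, math.exp(initial_val_L)))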

The main training loop

Our loss function will be the standard cross-entropy loss used for multi-class classification, applied at each time step to compare the model’s predictions to the true next word in the sequence. We can calculate gradients with respect to our parameters using truncated BPTT. In this case, we backpropagate for \(35\) time steps, updating our weights with stochastic gradient descent and a learning rate of \(20\); these correspond to the hyperparameters that we specified earlier in the notebook. The perplexity values reported below are simply the exponential of the average per-token cross-entropy loss.

[8]:
# Function for actually training the model
def train(model, train_data, val_data, test_data, epochs, lr):
    best_val = float("Inf")
    start_train_time = time.time()
    parameters = model.collect_params().values()

    for epoch in range(epochs):
        total_L = 0.0
        start_epoch_time = time.time()
        start_log_interval_time = time.time()
        # One hidden state per device; each device processes batch_size // len(context) sequences
        hiddens = [model.begin_state(batch_size//len(context), func=mx.nd.zeros, ctx=ctx)
                   for ctx in context]

        for i, (data, target) in enumerate(train_data):
            data_list = gluon.utils.split_and_load(data, context,
                                                   batch_axis=1, even_split=True)
            target_list = gluon.utils.split_and_load(target, context,
                                                     batch_axis=1, even_split=True)
            # Detach the hidden states carried over from the previous mini-batch
            # so that gradients are truncated at the BPTT boundary
            hiddens = detach(hiddens)
            L = 0
            Ls = []

            with autograd.record():
                for j, (X, y, h) in enumerate(zip(data_list, target_list, hiddens)):
                    output, h = model(X, h)
                    batch_L = loss(output.reshape(-3, -1), y.reshape(-1,))
                    L = L + batch_L.as_in_context(context[0]) / (len(context) * X.size)
                    Ls.append(batch_L / (len(context) * X.size))
                    hiddens[j] = h
            L.backward()
            # Clip the global gradient norm to keep training stable at this large learning rate
            grads = [p.grad(x.context) for p in parameters for x in data_list]
            gluon.utils.clip_global_norm(grads, grad_clip)

            trainer.step(1)

            total_L += sum([mx.nd.sum(l).asscalar() for l in Ls])

            if i % log_interval == 0 and i > 0:
                cur_L = total_L / log_interval
                print('[Epoch %d Batch %d/%d] loss %.2f, ppl %.2f, '
                      'throughput %.2f samples/s'%(
                    epoch, i, len(train_data), cur_L, math.exp(cur_L),
                    batch_size * log_interval / (time.time() - start_log_interval_time)))
                total_L = 0.0
                start_log_interval_time = time.time()

        mx.nd.waitall()

        print('[Epoch %d] throughput %.2f samples/s'%(
                    epoch, len(train_data)*batch_size / (time.time() - start_epoch_time)))

        val_L = evaluate(model, val_data, batch_size, context[0])
        print('[Epoch %d] time cost %.2fs, valid loss %.2f, valid ppl %.2f'%(
            epoch, time.time()-start_epoch_time, val_L, math.exp(val_L)))

        # Save a checkpoint and report test metrics whenever the validation loss improves
        if val_L < best_val:
            best_val = val_L
            test_L = evaluate(model, test_data, batch_size, context[0])
            model.save_parameters('{}_{}-{}.params'.format(model_name, dataset_name, epoch))
            print('test loss %.2f, test ppl %.2f'%(test_L, math.exp(test_L)))
        else:
            # Otherwise anneal the learning rate
            lr = lr*0.25
            print('Learning rate now %f'%(lr))
            trainer.set_learning_rate(lr)

    print('Total training throughput %.2f samples/s'%(
                            (batch_size * len(train_data) * epochs) /
                            (time.time() - start_train_time)))

We can now actually perform the training:

[9]:
train(model, train_data, val_data, test_data, epochs, lr)
[Epoch 0 Batch 200/2983] loss 7.66, ppl 2118.19, throughput 963.12 samples/s
[Epoch 0 Batch 400/2983] loss 6.75, ppl 855.02, throughput 985.42 samples/s
[Epoch 0 Batch 600/2983] loss 6.34, ppl 567.93, throughput 983.61 samples/s
[Epoch 0 Batch 800/2983] loss 6.18, ppl 484.97, throughput 967.04 samples/s
[Epoch 0 Batch 1000/2983] loss 6.04, ppl 420.57, throughput 984.71 samples/s
[Epoch 0 Batch 1200/2983] loss 5.97, ppl 389.89, throughput 979.67 samples/s
[Epoch 0 Batch 1400/2983] loss 5.85, ppl 348.51, throughput 975.21 samples/s
[Epoch 0 Batch 1600/2983] loss 5.86, ppl 351.57, throughput 975.35 samples/s
[Epoch 0 Batch 1800/2983] loss 5.71, ppl 300.86, throughput 975.68 samples/s
[Epoch 0 Batch 2000/2983] loss 5.66, ppl 288.36, throughput 970.52 samples/s
[Epoch 0 Batch 2200/2983] loss 5.57, ppl 261.72, throughput 979.43 samples/s
[Epoch 0 Batch 2400/2983] loss 5.57, ppl 262.95, throughput 959.92 samples/s
[Epoch 0 Batch 2600/2983] loss 5.56, ppl 260.04, throughput 959.70 samples/s
[Epoch 0 Batch 2800/2983] loss 5.46, ppl 234.50, throughput 956.28 samples/s
[Epoch 0] throughput 971.58 samples/s
[Epoch 0] time cost 63.66s, valid loss 5.42, valid ppl 225.34
test loss 5.34, test ppl 207.99
[Epoch 1 Batch 200/2983] loss 5.46, ppl 236.24, throughput 945.48 samples/s
[Epoch 1 Batch 400/2983] loss 5.45, ppl 232.58, throughput 952.36 samples/s
[Epoch 1 Batch 600/2983] loss 5.28, ppl 197.01, throughput 937.78 samples/s
[Epoch 1 Batch 800/2983] loss 5.29, ppl 199.31, throughput 935.65 samples/s
[Epoch 1 Batch 1000/2983] loss 5.27, ppl 193.45, throughput 934.45 samples/s
[Epoch 1 Batch 1200/2983] loss 5.26, ppl 191.60, throughput 931.26 samples/s
[Epoch 1 Batch 1400/2983] loss 5.26, ppl 192.67, throughput 933.48 samples/s
[Epoch 1 Batch 1600/2983] loss 5.33, ppl 205.91, throughput 931.33 samples/s
[Epoch 1 Batch 1800/2983] loss 5.20, ppl 180.43, throughput 931.53 samples/s
[Epoch 1 Batch 2000/2983] loss 5.21, ppl 182.21, throughput 927.04 samples/s
[Epoch 1 Batch 2200/2983] loss 5.12, ppl 167.40, throughput 921.72 samples/s
[Epoch 1 Batch 2400/2983] loss 5.15, ppl 171.58, throughput 917.41 samples/s
[Epoch 1 Batch 2600/2983] loss 5.17, ppl 175.32, throughput 917.16 samples/s
[Epoch 1 Batch 2800/2983] loss 5.08, ppl 161.37, throughput 917.40 samples/s
[Epoch 1] throughput 930.43 samples/s
[Epoch 1] time cost 66.44s, valid loss 5.16, valid ppl 173.93
test loss 5.08, test ppl 161.46
[Epoch 2 Batch 200/2983] loss 5.13, ppl 169.57, throughput 915.27 samples/s
[Epoch 2 Batch 400/2983] loss 5.15, ppl 172.81, throughput 910.73 samples/s
[Epoch 2 Batch 600/2983] loss 4.98, ppl 145.24, throughput 914.23 samples/s
[Epoch 2 Batch 800/2983] loss 5.03, ppl 152.60, throughput 915.43 samples/s
[Epoch 2 Batch 1000/2983] loss 5.01, ppl 150.60, throughput 919.61 samples/s
[Epoch 2 Batch 1200/2983] loss 5.01, ppl 150.61, throughput 914.80 samples/s
[Epoch 2 Batch 1400/2983] loss 5.04, ppl 155.07, throughput 920.81 samples/s
[Epoch 2 Batch 1600/2983] loss 5.11, ppl 166.49, throughput 918.37 samples/s
[Epoch 2 Batch 1800/2983] loss 4.99, ppl 147.16, throughput 921.73 samples/s
[Epoch 2 Batch 2000/2983] loss 5.01, ppl 150.54, throughput 918.48 samples/s
[Epoch 2 Batch 2200/2983] loss 4.92, ppl 137.50, throughput 920.43 samples/s
[Epoch 2 Batch 2400/2983] loss 4.96, ppl 142.78, throughput 920.45 samples/s
[Epoch 2 Batch 2600/2983] loss 4.98, ppl 145.86, throughput 916.54 samples/s
[Epoch 2 Batch 2800/2983] loss 4.91, ppl 135.98, throughput 919.60 samples/s
[Epoch 2] throughput 917.83 samples/s
[Epoch 2] time cost 67.30s, valid loss 5.05, valid ppl 156.59
test loss 4.98, test ppl 146.01
Total training throughput 871.64 samples/s
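
The training loop above saves a checkpoint named '{model_name}_{dataset_name}-{epoch}.params' every time the validation loss improves. The following is a minimal sketch for restoring such a checkpoint later; here we assume the file from the final epoch (epoch 2, the best one in the run above).

# Reload a saved checkpoint; the file name follows the pattern used inside train()
best_epoch = 2  # assumption based on the run above; use whichever epoch was saved last
model.load_parameters(
    '{}_{}-{}.params'.format(model_name, dataset_name, best_epoch), ctx=context)

# Re-running evaluation should reproduce the reported test perplexity
reloaded_L = evaluate(model, test_data, batch_size, context[0])
print('Reloaded model: test loss %.2f, test ppl %.2f' % (reloaded_L, math.exp(reloaded_L)))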

Using your own dataset

When we train a language model, we fit to the statistics of a given dataset. While many papers focus on a few standard datasets, such as WikiText or the Penn Treebank, that’s just to provide a standard benchmark for comparing models against one another. In general, for any given use case, you’ll want to train your own language model on a dataset of your own choice. Here, for demonstration, we’ll grab some .txt files corresponding to Sherlock Holmes novels.

We first download the new dataset.

[10]:
TRAIN_PATH = "./sherlockholmes.train.txt"
VALID_PATH = "./sherlockholmes.valid.txt"
TEST_PATH = "./sherlockholmes.test.txt"
PREDICT_PATH = "./tinyshakespeare/input.txt"
download(
    "https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/sherlockholmes/sherlockholmes.train.txt",
    TRAIN_PATH,
    sha1_hash="d65a52baaf32df613d4942e0254c81cff37da5e8")
download(
    "https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/sherlockholmes/sherlockholmes.valid.txt",
    VALID_PATH,
    sha1_hash="71133db736a0ff6d5f024bb64b4a0672b31fc6b3")
download(
    "https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/sherlockholmes/sherlockholmes.test.txt",
    TEST_PATH,
    sha1_hash="b7ccc4778fd3296c515a3c21ed79e9c2ee249f70")
download(
    "https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/tinyshakespeare/input.txt",
    PREDICT_PATH,
    sha1_hash="04486597058d11dcc2c556b1d0433891eb639d2e")

print(glob.glob("sherlockholmes.*.txt"))
Downloading ./sherlockholmes.train.txt from https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/sherlockholmes/sherlockholmes.train.txt...
Downloading ./sherlockholmes.valid.txt from https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/sherlockholmes/sherlockholmes.valid.txt...
Downloading ./sherlockholmes.test.txt from https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/sherlockholmes/sherlockholmes.test.txt...
Downloading ./tinyshakespeare/input.txt from https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/tinyshakespeare/input.txt...
['sherlockholmes.valid.txt', 'sherlockholmes.train.txt', 'sherlockholmes.test.txt']

Then we specify the tokenizer and batchify the dataset.

[11]:
import nltk  # provides the sentence splitter used below (requires NLTK's 'punkt' data)
moses_tokenizer = nlp.data.SacreMosesTokenizer()

sherlockholmes_datasets = [
    nlp.data.CorpusDataset(
        'sherlockholmes.{}.txt'.format(name),
        sample_splitter=nltk.tokenize.sent_tokenize,
        tokenizer=moses_tokenizer,
        flatten=True,
        eos='<eos>') for name in ['train', 'valid', 'test']
]

sherlockholmes_train_data, sherlockholmes_val_data, sherlockholmes_test_data = [
    bptt_batchify(dataset) for dataset in sherlockholmes_datasets
]
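
Note that we keep the vocabulary built from the WikiText-2 training set, so any Sherlock Holmes word that never appeared there is mapped to <unk>. The following minimal sketch checks how often that happens; it assumes that iterating over the flattened CorpusDataset yields individual tokens.

# Fraction of Sherlock Holmes training tokens that fall outside the WikiText-2 vocabulary
unk_idx = vocab[vocab.unknown_token]
train_tokens = list(sherlockholmes_datasets[0])
num_oov = sum(1 for token in train_tokens if vocab[token] == unk_idx)
print('OOV rate: %.2f%% of %d tokens'
      % (100.0 * num_oov / len(train_tokens), len(train_tokens)))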

We now evaluate the model, which was trained on WikiText-2, on the new dataset to see how well it transfers.

[12]:
sherlockholmes_L = evaluate(model, sherlockholmes_val_data, batch_size,
                            context[0])
print('Validation loss %.2f, validation ppl %.2f' %
      (sherlockholmes_L, math.exp(sherlockholmes_L)))
Validation loss 4.78, validation ppl 118.87

Alternatively, we can fine-tune the model on the new dataset with a single call to the train function we defined earlier.

[13]:
train(
    model,
    sherlockholmes_train_data,  # your own batchified training data would go here (tokenized and batchified as shown above)
    sherlockholmes_val_data,
    sherlockholmes_test_data,   # likewise, your own batchified test data
    epochs=3,
    lr=20)
[Epoch 0] throughput 897.04 samples/s
[Epoch 0] time cost 3.84s, valid loss 3.04, valid ppl 20.91
test loss 2.98, test ppl 19.61
[Epoch 1] throughput 916.07 samples/s
[Epoch 1] time cost 3.77s, valid loss 3.02, valid ppl 20.49
test loss 3.00, test ppl 20.13
[Epoch 2] throughput 920.58 samples/s
[Epoch 2] time cost 3.76s, valid loss 2.81, valid ppl 16.54
test loss 2.76, test ppl 15.86
Total training throughput 737.64 samples/s
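
The tinyshakespeare file downloaded earlier as PREDICT_PATH is not used in this section. As an optional extra check (a sketch, not part of the original workflow), the same pipeline can score it with the fine-tuned model; since it reuses the WikiText-2 vocabulary, expect a fairly high <unk> rate and perplexity.

# Tokenize, batchify, and evaluate the tinyshakespeare text with the current model
predict_dataset = nlp.data.CorpusDataset(
    PREDICT_PATH,
    sample_splitter=nltk.tokenize.sent_tokenize,
    tokenizer=moses_tokenizer,
    flatten=True,
    eos='<eos>')
predict_data = bptt_batchify(predict_dataset)
predict_L = evaluate(model, predict_data, batch_size, context[0])
print('tinyshakespeare loss %.2f, ppl %.2f' % (predict_L, math.exp(predict_L)))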