
Word Embeddings Training and Evaluation

Evaluating Word Embeddings

The previous example introduced how to load pre-trained word embeddings from a set of sources included in the GluonNLP toolkit. It showed how to use the word vectors to find the most similar words to a given word and to solve the analogy task.

Besides manually inspecting similar words or predicted analogous words, we can use word embedding evaluation datasets to quantify the evaluation.

Datasets for the similarity task come with a list of word pairs together with a human similarity judgement. The task is to recover the order of most-similar to least-similar pairs.
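The ranking is typically scored with the Spearman rank correlation between the predicted and the human scores, as we do later in this notebook. A minimal sketch on made-up numbers (both lists below are purely illustrative):

# Scoring predicted similarities against human judgements with Spearman rank
# correlation. All numbers below are made up for illustration.
from scipy import stats

human_scores     = [9.2, 7.4, 3.1, 0.5]    # e.g. (dollar, buck), (media, radio), ...
predicted_scores = [0.81, 0.55, 0.20, 0.31]

print(stats.spearmanr(predicted_scores, human_scores).correlation)  # 0.8: ranking mostly recovered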

Datasets for the analogy task supply a set of analogy quadruples of the form ‘a : b :: c : d’, and the task is to find the correct ‘d’ in as many cases as possible given just ‘a’, ‘b’, ‘c’. For instance, “man : woman :: son : daughter” is an analogy.
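A simple way to answer such a question, roughly sketched below on made-up 2-dimensional vectors, is to search for the vocabulary word whose vector is closest (in cosine similarity) to vec(b) - vec(a) + vec(c). The evaluator used later in this notebook relies on a refined variant ("ThreeCosMul") of this idea.

# Sketch of analogy prediction by vector arithmetic: d ≈ argmax_x cos(x, b - a + c).
# The 2-d vectors are made up for illustration only.
import numpy as np

vecs = {'man': np.array([1.0, 0.0]), 'woman': np.array([1.0, 1.0]),
        'son': np.array([2.0, 0.1]), 'daughter': np.array([2.0, 1.1]),
        'car': np.array([-1.0, 0.2])}

query = vecs['woman'] - vecs['man'] + vecs['son']    # man : woman :: son : ?
candidates = [w for w in vecs if w not in ('man', 'woman', 'son')]
best = max(candidates,
           key=lambda w: np.dot(vecs[w], query) /
                         (np.linalg.norm(vecs[w]) * np.linalg.norm(query)))
print(best)  # daughter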

The GluonNLP toolkit includes a set of popular similarity and analogy task datasets as well as helpers for computing the evaluation scores. Here we show how to make use of them.

In [1]:
import warnings
warnings.filterwarnings('ignore')

import itertools
import time
import math
import logging
import random

import mxnet as mx
import gluonnlp as nlp
import numpy as np
from scipy import stats

# context = mx.cpu()  # Enable this to run on CPU
context = mx.gpu(0)  # Enable this to run on GPU

We first load pre-trained FastText word embeddings.

In [2]:
embedding = nlp.embedding.create('fasttext', source='crawl-300d-2M')

vocab = nlp.Vocab(nlp.data.Counter(embedding.idx_to_token))
vocab.set_embedding(embedding)

Word Similarity and Relatedness Task

Word embeddings should capture the relationship between words in natural language. In the Word Similarity and Relatedness Task, word embeddings are evaluated by comparing word similarity scores computed from pairs of words with human labels for the similarity or relatedness of those pairs.
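A single similarity score is commonly computed as the cosine similarity between the two word vectors. A small sketch using the embedding loaded above (the word pairs are arbitrary examples; the evaluator block below does this in batch):

# Cosine similarity between two word vectors from the loaded FastText embedding.
def cos_sim(word1, word2):
    v1, v2 = vocab.embedding[word1], vocab.embedding[word2]
    return (mx.nd.dot(v1, v2) / (v1.norm() * v2.norm())).asscalar()

print(cos_sim('car', 'truck'), cos_sim('car', 'cabbage'))  # related pair vs. unrelated pair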

gluonnlp includes a number of common datasets for the Word Similarity and Relatedness Task. The included datasets are listed in the API documentation. We use several of them in the evaluation example below.

We first show a few samples from the WordSim353 dataset to get an overall feeling for the dataset's structure.

In [3]:
wordsim353 = nlp.data.WordSim353()
for i in range(15):
    print(*wordsim353[i], sep=', ')
announcement, warning, 6.0
population, development, 3.75
king, cabbage, 0.23
century, nation, 3.16
luxury, car, 6.47
disability, death, 5.47
mile, kilometer, 8.66
listing, category, 6.38
money, deposit, 7.73
fertility, egg, 6.69
planet, astronomer, 7.94
psychology, anxiety, 7.0
media, radio, 7.42
closet, clothes, 8.0
dollar, buck, 9.22

Similarity evaluator

The GluonNLP toolkit includes a WordEmbeddingSimilarity block, which predicts similarity scores between word pairs given an embedding matrix.

In [4]:
evaluator = nlp.embedding.evaluation.WordEmbeddingSimilarity(
    idx_to_vec=vocab.embedding.idx_to_vec,
    similarity_function="CosineSimilarity")
evaluator.initialize(ctx=context)
evaluator.hybridize()

Evaluation: Running the task

In [5]:
words1, words2, scores = zip(*([vocab[d[0]], vocab[d[1]], d[2]] for d in wordsim353))
words1 = mx.nd.array(words1, ctx=context)
words2 = mx.nd.array(words2, ctx=context)

The similarities can be predicted by passing the two arrays of words through the evaluator. The i-th word in words1 is compared with the i-th word in words2.

In [6]:
pred_similarity = evaluator(words1, words2)
print(pred_similarity[:5])

[ 0.30671987  0.19214539  0.18280022  0.29696473  0.2911745 ]
<NDArray 5 @gpu(0)>

We can evaluate the predicted similarities, and thereby the word embeddings, by computing the Spearman rank correlation between the predicted similarities and the ground-truth human similarity scores from the dataset:

In [7]:
sr = stats.spearmanr(pred_similarity.asnumpy(), np.array(scores))
print('Spearman rank correlation on {}: {}'.format(wordsim353.__class__.__name__,
                                                   sr.correlation.round(3)))
Spearman rank correlation on WordSim353: 0.79

Word Analogy Task

In the Word Analogy Task word embeddings are evaluated by inferring an analogous word D, which is related to a given word C in the same way as a given pair of words A, B are related.

gluonnlp includes a number of common datasets for the Word Analogy Task. The included datasets are listed in the API documentation. In this notebook we use the GoogleAnalogyTestSet dataset.

In [8]:
google_analogy = nlp.data.GoogleAnalogyTestSet()

We first demonstrate the structure of the dataset by printing a few examples.

In [9]:
sample = []
print(('Printing every 1000th analogy question '
       'from the {} questions '
       'in the Google Analogy Test Set:').format(len(google_analogy)))
print('')
for i in range(0, 19544, 1000):
    print(*google_analogy[i])
    sample.append(google_analogy[i])
Printing every 1000th analogy question from the 19544 questions in the Google Analogy Test Set:

athens greece baghdad iraq
baku azerbaijan dushanbe tajikistan
dublin ireland kathmandu nepal
lusaka zambia tehran iran
rome italy windhoek namibia
zagreb croatia astana kazakhstan
philadelphia pennsylvania tampa florida
wichita kansas shreveport louisiana
shreveport louisiana oxnard california
complete completely lucky luckily
comfortable uncomfortable clear unclear
good better high higher
young younger tight tighter
weak weakest bright brightest
slow slowing describe describing
ireland irish greece greek
feeding fed sitting sat
slowing slowed decreasing decreased
finger fingers onion onions
play plays sing sings

We restrict the candidate answers to the analogy questions to the first (most frequent) 300000 words of the pre-trained embedding, plus all tokens that occur in the evaluation dataset.

In [10]:
import itertools

most_freq = 300000
counter = nlp.data.utils.Counter(embedding.idx_to_token[:most_freq])
google_analogy_tokens = set(itertools.chain.from_iterable((d[0], d[1], d[2], d[3]) for d in google_analogy))
counter.update(t for t in google_analogy_tokens if t in embedding)

vocab = nlp.vocab.Vocab(counter)
vocab.set_embedding(embedding)

print("Using most frequent {} + {} extra words".format(most_freq, len(vocab) - most_freq))


google_analogy_subset = [
    d for i, d in enumerate(google_analogy) if
    d[0] in vocab and d[1] in vocab and d[2] in vocab and d[3] in vocab
]
print('Dropped {} pairs from {} as they were OOV.'.format(
    len(google_analogy) - len(google_analogy_subset),
    len(google_analogy)))

google_analogy_coded = [[vocab[d[0]], vocab[d[1]], vocab[d[2]], vocab[d[3]]]
                 for d in google_analogy_subset]
google_analogy_coded_batched = mx.gluon.data.DataLoader(
    google_analogy_coded, batch_size=256)
Using most frequent 300000 + 96 extra words
Dropped 1781 pairs from 19544 as they were OOV.
In [11]:
evaluator = nlp.embedding.evaluation.WordEmbeddingAnalogy(
    idx_to_vec=vocab.embedding.idx_to_vec,
    exclude_question_words=True,
    analogy_function="ThreeCosMul")
evaluator.initialize(ctx=context)
evaluator.hybridize()
In [12]:
acc = mx.metric.Accuracy()

for i, batch in enumerate(google_analogy_coded_batched):
    batch = batch.as_in_context(context)
    words1, words2, words3, words4 = (batch[:, 0], batch[:, 1],
                                      batch[:, 2], batch[:, 3])
    pred_idxs = evaluator(words1, words2, words3)
    acc.update(pred_idxs[:, 0], words4.astype(np.float32))

print('Accuracy on %s: %s'% (google_analogy.__class__.__name__, acc.get()[1].round(3)))
Accuracy on GoogleAnalogyTestSet: 0.772

Training word embeddings

In addition to making it easy to work with pre-trained word embeddings, gluonnlp also provides everything needed to train your own embeddings. Datasets as well as model definitions are included.

Loading the training data

We first load the Text8 corpus from the Large Text Compression Benchmark, which includes the first 100 MB of cleaned text from the English Wikipedia. We follow the common practice of splitting the corpus into chunks of 10,000 tokens to obtain “sentences” for embedding training.

In [13]:
dataset = nlp.data.Text8(segment='train')
print('# sentences:', len(dataset))
for sentence in dataset[:3]:
    print('# tokens:', len(sentence), sentence[:5])
# sentences: 1701
# tokens: 10000 ['anarchism', 'originated', 'as', 'a', 'term']
# tokens: 10000 ['reciprocity', 'qualitative', 'impairments', 'in', 'communication']
# tokens: 10000 ['with', 'the', 'aegis', 'of', 'zeus']

We then build a vocabulary of all the tokens in the dataset that occur at least 5 times and replace the words with their indices.

In [14]:
counter = nlp.data.count_tokens(itertools.chain.from_iterable(dataset))
vocab = nlp.Vocab(
    counter,
    unknown_token=None,
    padding_token=None,
    bos_token=None,
    eos_token=None,
    min_freq=5)

def code(s):
    return [vocab[t] for t in s if t in vocab]

coded_dataset = dataset.transform(code, lazy=False)

Some words such as “the”, “a”, and “in” are very frequent. One important trick applied when training word2vec is to subsample the dataset according to the token frequencies. [1] proposes to discard individual occurrences of words from the dataset with probability

\[P(w_i) = 1 - \sqrt{\frac{t}{f(w_i)}}\]

where \(f(w_i)\) is the frequency with which a word is observed in a dataset and \(t\) is a subsampling constant typically chosen around \(10^{-5}\).

[1] Mikolov, Tomas, et al. “Distributed representations of words and phrases and their compositionality.” Advances in neural information processing systems. 2013.
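For example, a word that accounts for roughly 5% of all tokens would be discarded with probability close to 0.99 at each occurrence (the frequency below is chosen purely for illustration):

# Worked example of the discard probability for an illustrative word frequency.
import math

t = 1e-5
f = 0.05                      # an extremely frequent word
print(1 - math.sqrt(t / f))   # ≈ 0.986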

In [15]:
subsampling_constant = 1e-5

idx_to_count = [counter[w] for w in vocab.idx_to_token]
total_count = sum(idx_to_count)
idx_to_pdiscard = [
    1 - math.sqrt(subsampling_constant / (count / total_count))
    for count in idx_to_count
]


def subsample(s):
    return [
        t for t, r in zip(s, np.random.uniform(0, 1, size=len(s)))
        if r > idx_to_pdiscard[t]
    ]


subsampled_dataset = coded_dataset.transform(subsample, lazy=False)

print('# tokens for sentences in coded_dataset:')
for i in range(3):
    print(len(coded_dataset[i]), coded_dataset[i][:5])

print('\n# tokens for sentences in subsampled_dataset:')
for i in range(3):
    print(len(subsampled_dataset[i]), subsampled_dataset[i][:5])
# tokens for sentences in coded_dataset:
9895 [5233, 3083, 11, 5, 194]
9858 [18214, 17356, 36672, 4, 1753]
9926 [23, 0, 19754, 1, 4829]

# tokens for sentences in subsampled_dataset:
2958 [3083, 127, 741, 10619, 27497]
2816 [18214, 17356, 36672, 13001, 3]
2771 [19754, 1799, 8712, 16334, 6690]

Model definition

gluonnlp provides model definitions for popular embedding models as Gluon Blocks. Here we show how to train them with the Skip-Gram objective, a simple and popular embedding training objective. It was introduced by “Tomas Mikolov, Kai Chen, Greg Corrado, and Jeffrey Dean. Efficient estimation of word representations in vector space. ICLR Workshop, 2013.”

The Skip-Gram objective trains word vectors such that the word vector of a word at some position in a sentence can best predict the surrounding words. We call the word at that position the center word and the surrounding words context words (see the sketch below).

Figure: the Skip-Gram model (picture from “Tomas Mikolov, Kai Chen, Greg Corrado, and Jeffrey Dean. Efficient estimation of word representations in vector space. ICLR Workshop, 2013.”)
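For example, with a fixed window size of 2, the (center, context) pairs of a toy sentence can be enumerated as in the sketch below. The EmbeddingCenterContextBatchify helper used later in this notebook produces such pairs at scale, together with batching and masking.

# Sketch: enumerating (center, context) words with a fixed window size of 2.
sentence = ['the', 'quick', 'brown', 'fox', 'jumps']
window = 2
for i, center in enumerate(sentence):
    context = [sentence[j]
               for j in range(max(0, i - window), min(len(sentence), i + window + 1))
               if j != i]
    print(center, '->', context)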

For the Skip-Gram objective, we initialize two embedding models: embedding and embedding_out. embedding is used to look up embeddings for the center words. embedding_out is used for the context words.

The weights of embedding are the final word embedding weights.

In [44]:
emsize = 300

embedding = nlp.model.train.SimpleEmbeddingModel(
    token_to_idx=vocab.token_to_idx,
    embedding_size=emsize,
    weight_initializer=mx.init.Uniform(scale=1 / emsize))
embedding_out = nlp.model.train.SimpleEmbeddingModel(
    token_to_idx=vocab.token_to_idx,
    embedding_size=emsize,
    weight_initializer=mx.init.Uniform(scale=1 / emsize))

embedding.initialize(ctx=context)
embedding_out.initialize(ctx=context)
embedding.hybridize(static_alloc=True)
embedding_out.hybridize(static_alloc=True)

params = list(embedding.collect_params().values()) + \
    list(embedding_out.collect_params().values())
trainer = mx.gluon.Trainer(params, 'adagrad', dict(learning_rate=0.05))

Before we start training, let’s examine the quality of our randomly initialized embeddings:

In [45]:
def norm_vecs_by_row(x):
    return x / (mx.nd.sum(x * x, axis=1) + 1e-10).sqrt().reshape((-1, 1))


def get_k_closest_tokens(vocab, embedding, k, word):
    word_vec = embedding(mx.nd.array([vocab.token_to_idx[word]],
                                     ctx=context)).reshape((-1, 1))
    vocab_vecs = norm_vecs_by_row(embedding.embedding.weight.data())
    dot_prod = mx.nd.dot(vocab_vecs, word_vec)
    indices = mx.nd.topk(
        dot_prod.reshape((len(vocab.idx_to_token), )),
        k=k + 1,
        ret_typ='indices')
    indices = [int(i.asscalar()) for i in indices]
    result = [vocab.idx_to_token[i] for i in indices[1:]]
    print('closest tokens to "%s": %s' % (word, ", ".join(result)))


example_token = "data"
get_k_closest_tokens(vocab, embedding, 10, example_token)
closest tokens to "data": icarius, pedagogy, weasels, transgression, flavonoids, moguls, unvisited, basse, nonspecific, scrolling

Training objective

Naive objective

To naively maximize the Skip-Gram objective, for every sampled center word we would predict, for each word in the vocabulary, whether it occurs in the context of the center word or not. We can then backpropagate and update the parameters to make the correct context words more likely and all other words less likely.
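A rough sketch of such a naive full-softmax prediction for a single center word, using the blocks defined above (illustrative only; it is not used for training in this notebook):

# Naive Skip-Gram scoring: a full softmax over the vocabulary for one center word.
center_vec = embedding(mx.nd.array([vocab['data']], ctx=context))   # shape (1, emsize)
all_context_vecs = embedding_out.embedding.weight.data()            # shape (|V|, emsize)
scores = mx.nd.dot(center_vec, all_context_vecs, transpose_b=True)  # shape (1, |V|)
log_probs = mx.nd.log_softmax(scores)                               # normalizes over the full vocabulary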

However, this naive method is computationally very expensive, as it requires computing a softmax over all words in the vocabulary. Instead, Mikolov et al. [1] introduced Negative Sampling.

Negative sampling

Negative Sampling means that instead of computing the loss with the small number of correct (positive) context words and all other (negative) words in the vocabulary, we choose a small, constant number of negative words at random. Negative words are chosen randomly based on their frequency in the training corpus; it is recommended to smooth the frequency distribution by raising the counts to the power of 0.75.
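The following sketch shows the effect of this smoothing on the sampling distribution, using the counts computed earlier (the choice of comparison tokens is arbitrary):

# Effect of raising unigram counts to the power 0.75 before normalizing:
# rare words gain relative sampling probability, frequent words lose some.
counts = np.array(idx_to_count, dtype=np.float64)
p_raw = counts / counts.sum()
p_smooth = counts**0.75 / (counts**0.75).sum()

for idx in [vocab['the'], len(vocab) - 1]:   # a very frequent and a very rare token
    print('%s: raw %.2e -> smoothed %.2e'
          % (vocab.idx_to_token[idx], p_raw[idx], p_smooth[idx]))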

We can use the UnigramCandidateSampler to sample tokens according to the given unigram weights.

In [46]:
num_negatives = 5
weights = mx.nd.array(idx_to_count)**0.75
negatives_sampler = nlp.data.UnigramCandidateSampler(weights)

Center and context words

We can use EmbeddingCenterContextBatchify to transform a corpus into batches of center and context words.

In [47]:
batch_size = 2048
window_size = 5
batchify = nlp.data.batchify.EmbeddingCenterContextBatchify(batch_size=batch_size, window_size=window_size)
batches = batchify(subsampled_dataset)

To compute the loss with negative sampling we use SigmoidBinaryCrossEntropyLoss.

In [48]:
loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
In [49]:
def remove_accidental_hits(candidates, true_samples):
    """Compute a candidates_mask surpressing accidental hits.

    Accidental hits are candidates that occur in the same batch dimension of
    true_samples.

    """
    candidates_np = candidates.asnumpy()
    true_samples_np = true_samples.asnumpy()

    candidates_mask = np.ones(candidates.shape, dtype=np.bool_)
    for j in range(true_samples.shape[1]):
        candidates_mask &= ~(candidates_np == true_samples_np[:, j:j + 1])

    return candidates, mx.nd.array(candidates_mask, ctx=candidates.context)


def skipgram_batch(data):
    """Create a batch for Skipgram training objective."""
    centers, word_context, word_context_mask = data
    assert len(centers.shape) == 2
    negatives_shape = (len(word_context), 2 * window_size * num_negatives)
    negatives, negatives_mask = remove_accidental_hits(
        negatives_sampler(negatives_shape), word_context)
    context_negatives = mx.nd.concat(word_context, negatives, dim=1)
    masks = mx.nd.concat(word_context_mask, negatives_mask, dim=1)
    labels = mx.nd.concat(word_context_mask, mx.nd.zeros_like(negatives), dim=1)
    return (centers.as_in_context(context),
            context_negatives.as_in_context(context),
            masks.as_in_context(context),
            labels.as_in_context(context))
In [50]:
def train_embedding(num_epochs):
    for epoch in range(1, num_epochs + 1):
        start_time = time.time()
        train_l_sum = 0
        num_samples = 0
        for i, data in enumerate(batches):
            (center, context_and_negative, mask,
             label) = skipgram_batch(data)
            with mx.autograd.record():
                emb_in = embedding(center)
                emb_out = embedding_out(context_and_negative)
                pred = mx.nd.batch_dot(emb_in, emb_out.swapaxes(1, 2))
                l = (loss(pred.reshape(label.shape), label, mask) *
                     mask.shape[1] / mask.sum(axis=1))
            l.backward()
            trainer.step(1)
            train_l_sum += l.sum()
            num_samples += center.shape[0]
            if i % 500 == 0:
                mx.nd.waitall()
                wps = num_samples / (time.time() - start_time)
                print('epoch %d, time %.2fs, iteration %d, throughput=%.2fK wps'
                      % (epoch, time.time() - start_time, i, wps / 1000))

        print('epoch %d, time %.2fs, train loss %.2f'
              % (epoch, time.time() - start_time,
                 train_l_sum.asscalar() / num_samples))
        get_k_closest_tokens(vocab, embedding, 10, example_token)
        print("")
In [51]:
train_embedding(num_epochs=5)
epoch 1, time 0.38s, iteration 0, throughput=5.42K wps
epoch 1, time 6.49s, iteration 500, throughput=158.19K wps
epoch 1, time 12.61s, iteration 1000, throughput=162.62K wps
epoch 1, time 18.71s, iteration 1500, throughput=164.30K wps
epoch 1, time 24.86s, iteration 2000, throughput=164.86K wps
epoch 1, time 27.28s, train loss 0.35
closest tokens to "data": storage, extensions, architectures, accessing, applications, packages, interoperability, identifier, interfaces, executable

epoch 2, time 0.37s, iteration 0, throughput=5.53K wps
epoch 2, time 6.52s, iteration 500, throughput=157.36K wps
epoch 2, time 12.67s, iteration 1000, throughput=161.75K wps
epoch 2, time 18.80s, iteration 1500, throughput=163.53K wps
epoch 2, time 24.89s, iteration 2000, throughput=164.64K wps
epoch 2, time 27.43s, train loss 0.31
closest tokens to "data": metadata, protocols, addressable, routing, asynchronous, identifier, dynamically, encoding, packet, optimization

epoch 3, time 0.37s, iteration 0, throughput=5.53K wps
epoch 3, time 6.47s, iteration 500, throughput=158.53K wps
epoch 3, time 12.58s, iteration 1000, throughput=162.95K wps
epoch 3, time 18.67s, iteration 1500, throughput=164.64K wps
epoch 3, time 24.76s, iteration 2000, throughput=165.48K wps
epoch 3, time 27.32s, train loss 0.30
closest tokens to "data": addressable, terabytes, indexing, metadata, encapsulated, encode, lossless, storage, coding, encoding

epoch 4, time 0.37s, iteration 0, throughput=5.55K wps
epoch 4, time 6.45s, iteration 500, throughput=158.97K wps
epoch 4, time 12.55s, iteration 1000, throughput=163.40K wps
epoch 4, time 18.64s, iteration 1500, throughput=164.94K wps
epoch 4, time 24.72s, iteration 2000, throughput=165.79K wps
epoch 4, time 27.19s, train loss 0.29
closest tokens to "data": addressable, indexing, terabytes, encapsulated, lossless, storing, storage, lossy, coding, encode

epoch 5, time 0.37s, iteration 0, throughput=5.58K wps
epoch 5, time 6.46s, iteration 500, throughput=158.89K wps
epoch 5, time 12.55s, iteration 1000, throughput=163.38K wps
epoch 5, time 18.64s, iteration 1500, throughput=164.92K wps
epoch 5, time 24.72s, iteration 2000, throughput=165.79K wps
epoch 5, time 27.23s, train loss 0.28
closest tokens to "data": lossless, addressable, terabytes, lossy, storage, storing, decoder, indexing, encapsulated, coding

Evaluation of trained embedding

As we have only obtained word vectors for words that occurred in the training corpus, we filter the evaluation dataset and exclude out-of-vocabulary words.

In [52]:
words1, words2, scores = zip(*([vocab[d[0]], vocab[d[1]], d[2]]
    for d in wordsim353  if d[0] in vocab and d[1] in vocab))
words1 = mx.nd.array(words1, ctx=context)
words2 = mx.nd.array(words2, ctx=context)

We create a new TokenEmbedding object and set the embedding vectors for the words we care about for evaluation.

In [53]:
token_embedding = nlp.embedding.TokenEmbedding(unknown_token=None, allow_extend=True)
token_embedding[vocab.idx_to_token] = embedding[vocab.idx_to_token]

evaluator = nlp.embedding.evaluation.WordEmbeddingSimilarity(
    idx_to_vec=token_embedding.idx_to_vec,
    similarity_function="CosineSimilarity")
evaluator.initialize(ctx=context)
evaluator.hybridize()
In [54]:
pred_similarity = evaluator(words1, words2)
sr = stats.spearmanr(pred_similarity.asnumpy(), np.array(scores))
print('Spearman rank correlation on {} pairs of {} (total {}): {}'.format(
    len(words1), wordsim353.__class__.__name__, len(wordsim353), sr.correlation.round(3)))
Spearman rank correlation on 332 pairs of WordSim353 (total 352): 0.627

Unknown token handling and subword information

Sometimes we may run into a word for which the embedding does not include a word vector. The vocab object is happy to replace such a word with a special index for unknown tokens, but that index does not carry any information about the word itself.

In [55]:
print('Is "hello" known? ', 'hello' in vocab)
print('Is "likelyunknown" known? ', 'likelyunknown' in vocab)
Is "hello" known?  True
Is "likelyunknown" known?  False

Some embedding models such as the FastText model support computing word vectors for unknown words by taking into account their subword units.

  • Piotr Bojanowski, Edouard Grave, Armand Joulin, and Tomas Mikolov. Enriching Word Vectors with Subword Information. TACL, 2017.
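Conceptually, the vector of a word is composed from the vectors of its subword units, so a word never seen during training can still receive a sensible representation. A toy numpy sketch of this idea (the subword vectors are made up; the FasttextEmbeddingModel trained below implements the actual composition):

# Toy sketch: composing an out-of-vocabulary word vector from subword vectors.
import numpy as np

subword_vectors = {'lik': np.array([0.1, 0.2]),
                   'ike': np.array([0.0, 0.3]),
                   'kel': np.array([0.2, 0.1])}   # made-up 2-d vectors

def oov_vector(subwords):
    # a simple mean over the subword vectors
    return np.mean([subword_vectors[s] for s in subwords], axis=0)

print(oov_vector(['lik', 'ike', 'kel']))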

Training word embeddings with subword information

gluonnlp provides the concept of a SubwordFunction which maps words to a list of indices representing their subwords. Possible SubwordFunctions include mapping a word to the sequence of its characters/bytes, or to hashes of all its ngrams.

FastText models use a hash function to map each ngram of a word to a number in the range [0, num_subwords). We include the same hash function.

Concept of a SubwordFunction
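As an illustration of which subwords are produced, the sketch below enumerates the character ngrams of a word wrapped in the boundary markers '<' and '>'. The hashing of each ngram into [0, num_subwords) is handled by the library and is not reproduced here.

# Character ngrams (n = 3..6) of a word wrapped in '<' and '>'.
def char_ngrams(word, ns=(3, 4, 5, 6)):
    s = '<' + word + '>'
    return [s[i:i + n] for n in ns for i in range(len(s) - n + 1)]

print(char_ngrams('the'))   # 6 ngrams, matching the 6 subword indices printed below
print(char_ngrams('of'))    # 3 ngrams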

In [56]:
subword_function = nlp.vocab.create_subword_function(
    'NGramHashes', ngrams=[3, 4, 5, 6], num_subwords=500000)

idx_to_subwordidxs = subword_function(vocab.idx_to_token)
for word, subwords in zip(vocab.idx_to_token[:3], idx_to_subwordidxs[:3]):
    print('<'+word+'>', subwords, sep = '\t')
<the>     [151151, 409726, 148960, 361980, 60934, 316280]
<of>      [497102, 164528, 228930]
<and>     [378080, 235020, 30390, 395046, 119624, 125443]

As words are of varying length, we have to pad the lists of subwords to obtain a batch. To distinguish padded values from valid subword indices we use a mask. We first pad the subword arrays with -1, compute the mask and change the -1 entries to some valid subword index (here 0).

In [57]:
subword_padding = nlp.data.batchify.Pad(pad_val=-1)

subwords = subword_padding(idx_to_subwordidxs[:3])
subwords_mask = subwords != -1
subwords += subwords == -1  # -1 is invalid. Change to 0
print(subwords)
print(subwords_mask)

[[ 151151.  409726.  148960.  361980.   60934.  316280.]
 [ 497102.  164528.  228930.       0.       0.       0.]
 [ 378080.  235020.   30390.  395046.  119624.  125443.]]
<NDArray 3x6 @cpu_shared(0)>

[[ 1.  1.  1.  1.  1.  1.]
 [ 1.  1.  1.  0.  0.  0.]
 [ 1.  1.  1.  1.  1.  1.]]
<NDArray 3x6 @cpu(0)>

The model

Instead of the SimpleEmbeddingModel, we now train a FasttextEmbeddingModel block, which combines word and subword information.

In [58]:
emsize = 300
embedding = nlp.model.train.FasttextEmbeddingModel(
    token_to_idx=vocab.token_to_idx,
    subword_function=subword_function,
    embedding_size=emsize,
    weight_initializer=mx.init.Uniform(scale=1 / emsize))
embedding_out = nlp.model.train.SimpleEmbeddingModel(
    token_to_idx=vocab.token_to_idx,
    embedding_size=emsize,
    weight_initializer=mx.init.Uniform(scale=1 / emsize))
loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()

embedding.initialize(ctx=context)
embedding_out.initialize(ctx=context)
embedding.hybridize(static_alloc=True)
embedding_out.hybridize(static_alloc=True)

params = list(embedding.collect_params().values()) + \
    list(embedding_out.collect_params().values())
trainer = mx.gluon.Trainer(params, 'adagrad', dict(learning_rate=0.05))

Training

Compared to training the SimpleEmbeddingModel, we now also look up the subwords of each center word in the batch and pass the subword information to the embedding block.

In [59]:
from gluonnlp.base import numba_jitclass, numba_types, numba_prange

@numba_jitclass([('idx_to_subwordidxs',
                  numba_types.List(numba_types.int_[::1]))])
class SubwordLookup(object):
    """Just-in-time compiled helper class for fast, padded subword lookup.

    SubwordLookup holds a mapping from token indices to variable length subword
    arrays and allows fast access to padded and masked batches of subwords
    given a list of token indices.

    Parameters
    ----------
    length : int
         Number of tokens for which to hold subword arrays.

    """
    def __init__(self, length):
        self.idx_to_subwordidxs = [
            np.arange(1).astype(np.int_) for _ in range(length)
        ]

    def set(self, i, subwords):
        """Set the subword array of the i-th token."""
        self.idx_to_subwordidxs[i] = subwords

    def get(self, indices):
        """Get a padded array and mask of subwords for specified indices."""
        subwords = [self.idx_to_subwordidxs[i] for i in indices]
        lengths = np.array([len(s) for s in subwords])
        length = np.max(lengths)
        subwords_arr = np.zeros((len(subwords), length))
        mask = np.zeros((len(subwords), length))
        for i in numba_prange(len(subwords)):
            s = subwords[i]
            subwords_arr[i, :len(s)] = s
            mask[i, :len(s)] = 1
        return subwords_arr, mask

subword_lookup = SubwordLookup(len(idx_to_subwordidxs))
for i, subwords in enumerate(idx_to_subwordidxs):
    subword_lookup.set(i, np.array(subwords, dtype=np.int_))

def skipgram_fasttext_batch(data):
    """Create a batch for Skipgram training objective."""
    centers, word_context, word_context_mask = data
    assert len(centers.shape) == 2
    negatives_shape = (len(word_context), 2 * window_size * num_negatives)
    negatives, negatives_mask = remove_accidental_hits(
        negatives_sampler(negatives_shape), word_context)
    context_negatives = mx.nd.concat(word_context, negatives, dim=1)
    masks = mx.nd.concat(word_context_mask, negatives_mask, dim=1)
    labels = mx.nd.concat(word_context_mask, mx.nd.zeros_like(negatives), dim=1)

    unique, inverse_unique_indices = np.unique(centers.asnumpy(),
                                               return_inverse=True)
    inverse_unique_indices = mx.nd.array(inverse_unique_indices,
                                         ctx=context)
    subwords, subwords_mask = subword_lookup.get(unique.astype(int))

    return (centers.as_in_context(context),
            context_negatives.as_in_context(context),
            masks.as_in_context(context),
            labels.as_in_context(context),
            mx.nd.array(subwords, ctx=context),
            mx.nd.array(subwords_mask, ctx=context),
            inverse_unique_indices)
In [61]:
def train_fasttext_embedding(num_epochs):
    for epoch in range(1, num_epochs + 1):
        start_time = time.time()
        train_l_sum = 0
        num_samples = 0
        for i, data in enumerate(batches):
            (center, context_negatives, mask, label, subwords,
             subwords_mask, inverse_unique_indices) = skipgram_fasttext_batch(data)
            with mx.autograd.record():
                emb_in = embedding(center, subwords,
                   subwordsmask=subwords_mask,
                   words_to_unique_subwords_indices=
                   inverse_unique_indices)
                emb_out = embedding_out(context_negatives, mask)
                pred = mx.nd.batch_dot(emb_in, emb_out.swapaxes(1, 2))
                l = (loss(pred, label, mask) * mask.shape[1] / mask.sum(axis=1))
            l.backward()
            trainer.step(1)
            train_l_sum += l.sum()
            num_samples += center.shape[0]
            if i % 500 == 0:
                mx.nd.waitall()
                wps = num_samples / (time.time() - start_time)
                print('epoch %d, time %.2fs, iteration %d, throughput=%.2fK wps'
                      % (epoch, time.time() - start_time, i, wps / 1000))

        print('epoch %d, time %.2fs, train loss %.2f'
              % (epoch, time.time() - start_time,
                 train_l_sum.asscalar() / num_samples))
        print("")
In [62]:
train_fasttext_embedding(num_epochs=5)
epoch 1, time 1.15s, iteration 0, throughput=1.79K wps
epoch 1, time 31.93s, iteration 500, throughput=32.13K wps
epoch 1, time 62.05s, iteration 1000, throughput=33.04K wps
epoch 1, time 91.97s, iteration 1500, throughput=33.42K wps
epoch 1, time 121.87s, iteration 2000, throughput=33.63K wps
epoch 1, time 125.26s, train loss 0.29

epoch 2, time 0.42s, iteration 0, throughput=4.90K wps
epoch 2, time 30.46s, iteration 500, throughput=33.68K wps
epoch 2, time 60.69s, iteration 1000, throughput=33.78K wps
epoch 2, time 90.62s, iteration 1500, throughput=33.92K wps
epoch 2, time 121.22s, iteration 2000, throughput=33.81K wps
epoch 2, time 124.74s, train loss 0.27

epoch 3, time 0.43s, iteration 0, throughput=4.76K wps
epoch 3, time 30.75s, iteration 500, throughput=33.37K wps
epoch 3, time 60.96s, iteration 1000, throughput=33.63K wps
epoch 3, time 90.88s, iteration 1500, throughput=33.82K wps
epoch 3, time 120.96s, iteration 2000, throughput=33.88K wps
epoch 3, time 124.37s, train loss 0.27

epoch 4, time 0.43s, iteration 0, throughput=4.76K wps
epoch 4, time 30.54s, iteration 500, throughput=33.59K wps
epoch 4, time 60.52s, iteration 1000, throughput=33.87K wps
epoch 4, time 90.60s, iteration 1500, throughput=33.93K wps
epoch 4, time 121.34s, iteration 2000, throughput=33.77K wps
epoch 4, time 125.48s, train loss 0.26

epoch 5, time 0.43s, iteration 0, throughput=4.75K wps
epoch 5, time 31.41s, iteration 500, throughput=32.66K wps
epoch 5, time 62.20s, iteration 1000, throughput=32.96K wps
epoch 5, time 93.28s, iteration 1500, throughput=32.96K wps
epoch 5, time 123.72s, iteration 2000, throughput=33.12K wps
epoch 5, time 127.43s, train loss 0.26

Evaluation

Thanks to the subword support of the FasttextEmbeddingModel, we can now evaluate on all words in the evaluation dataset, not only the ones observed during training (the SimpleEmbeddingModel only provides vectors for words observed during training).

We first find all tokens in the evaluation dataset and then convert the FasttextEmbeddingModel to a TokenEmbedding with exactly those tokens.

In [63]:
wordsim353_tokens  = list(set(itertools.chain.from_iterable((d[0], d[1]) for d in wordsim353)))
token_embedding = nlp.embedding.TokenEmbedding(unknown_token=None, allow_extend=True)
token_embedding[wordsim353_tokens] = embedding[wordsim353_tokens]

print('There are', len(wordsim353_tokens), 'unique tokens in WordSim353')
print('The imputed TokenEmbedding has shape', token_embedding.idx_to_vec.shape)
There are 437 unique tokens in WordSim353
The imputed TokenEmbedding has shape (437, 300)
In [64]:
evaluator = nlp.embedding.evaluation.WordEmbeddingSimilarity(
    idx_to_vec=token_embedding.idx_to_vec,
    similarity_function="CosineSimilarity")
evaluator.initialize(ctx=context)
evaluator.hybridize()
In [65]:
words1, words2, scores = zip(*([token_embedding.token_to_idx[d[0]],
                                token_embedding.token_to_idx[d[1]],
                                d[2]] for d in wordsim353))
words1 = mx.nd.array(words1, ctx=context)
words2 = mx.nd.array(words2, ctx=context)
In [66]:
pred_similarity = evaluator(words1, words2)
sr = stats.spearmanr(pred_similarity.asnumpy(), np.array(scores))
print('Spearman rank correlation on {} pairs of {}: {}'.format(
    len(words1), wordsim353.__class__.__name__, sr.correlation.round(3)))
Spearman rank correlation on 352 pairs of WordSim353: 0.569