
A Structured Self-attentive Sentence Embedding

Since word embeddings came into use for representing words, natural language processing (NLP) has improved in many ways. Along with the widespread use of word embeddings, many techniques have been developed to express the semantics of a sentence through its words, for example:

1. Concatenate the vector representations of the words in a sentence, or take a (weighted) average of them, to obtain a single vector that represents the sentence.
2. Apply convolution (CNN) and max pooling (MaxPooling) to the matrix of all word vectors of the sentence, and use the final result to represent its semantics.
3. Unroll the sentence word by word, feed the vector representation of each word into a recurrent neural network (RNN) one time step at a time, and use the output of the last time step of the RNN as the semantic representation of the sentence.

The above methods solve the problem of representing sentence meaning to a certain extent. With concatenation in the first method, if the sentence is long and the word vectors have even a moderately large dimension, the sentence vector becomes extremely large, and the interactions between the words of the sentence are not taken into account; weighted averaging, in turn, is imprecise and does not adequately express the impact of each word on the sentence's semantics. In the second method, a lot of useful word-level information may be lost in pooling. The third method uses only the output of the last time step; if the sentence is long, that output does not accurately capture the semantics of the whole sentence.

Motivated by these shortcomings, Zhouhan Lin, Minwei Feng et al. published the paper A Structured Self-attentive Sentence Embedding[1] in 2017, proposing a self-attention-based method for sentence embedding and applying it to tasks such as user review classification and textual entailment, with good results.

In this notebook, we will use GluonNLP to reproduce the model structure from A Structured Self-attentive Sentence Embedding and apply it to the Yelp review star-rating dataset for classification.

Data pipeline

Load the Yelp Review Data

The Yelp user review dataset is formatted as JSON. The original paper selected 500K documents as the training set, 2K as the validation set, and 2K as the test set. For easier reproducibility of the experiment, we subsampled 198K documents from this dataset as the training set and 2K documents as the validation set. Each sample in the data is a user's comment written in English, labeled 1-5 to represent five different sentiment levels.
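
The code cells below rely on the imports and the try_gpu helper defined in the notebook's earlier setup cells, which are not included in this excerpt. A minimal sketch of such a setup (the exact original setup cell may differ) is:

import json
import zipfile
import time
import itertools
import os
import multiprocessing as mp

import numpy as np
import mxnet as mx
from mxnet import nd, gluon, autograd, init
from mxnet.gluon import nn, rnn
import gluonnlp as nlp
from sklearn.metrics import accuracy_score, f1_score

def try_gpu():
    # return mx.gpu() if a GPU is usable, otherwise fall back to mx.cpu()
    try:
        ctx = mx.gpu()
        nd.zeros((1,), ctx=ctx)
    except mx.base.MXNetError:
        ctx = mx.cpu()
    return ctx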

In [3]:
# download the data zip from server.
data_url = 'http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/yelp_review_subset-167bb781.zip'
zip_path = mx.gluon.utils.download(data_url)

# unzip the zip file.
zip_file = zipfile.ZipFile(zip_path)
json_path = zip_file.extract(zip_file.namelist()[0])

# load the json data.
with open(json_path, 'r', encoding='utf-8') as fr:
    data = json.load(fr)

# create a list of (review, label) pairs.
dataset = [[text, int(label)] for text, label in zip(data['texts'], data['labels'])]

# randomly split off one percent of the dataset as a validation set.
train_dataset, valid_dataset = nlp.data.train_valid_split(dataset, 0.01)
len(train_dataset), len(valid_dataset)
Downloading yelp_review_subset-167bb781.zip from http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/yelp_review_subset-167bb781.zip...
Out[3]:
(198000, 2000)

Data Processing

The purpose of the following code is to process the raw data so that the processed data can be used for model training and prediction. We use SpacyTokenizer to split each document into tokens, ClipSequence to clip the comments to a specified length, and then build a vocabulary based on the word frequencies of the training data. Next we attach the GloVe[2] pre-trained word vectors to the vocabulary and convert each token into the corresponding word index in the vocabulary. Finally, we obtain the standardized training and validation datasets.

In [4]:
# tokenizer takes as input a string and outputs a list of tokens.
tokenizer = nlp.data.SpacyTokenizer('en')

# length_clip takes as input a list and outputs a list with maximum length 100.
length_clip = nlp.data.ClipSequence(100)

def preprocess(x):

    # convert the number of stars 1, 2, 3, 4, 5 to zero-based index, 0, 1, 2, 3, 4
    data, label = x[0], x[1]-1

    # clip the length of review words
    data = length_clip(tokenizer(data))
    return data, label

def get_length(x):
    return float(len(x[0]))

def preprocess_dataset(dataset):
    start = time.time()

    with mp.Pool() as pool:
        # Each sample is processed in parallel by a pool of worker processes.
        dataset = gluon.data.SimpleDataset(pool.map(preprocess, dataset))
        lengths = gluon.data.SimpleDataset(pool.map(get_length, dataset))
    end = time.time()

    print('Done! Tokenizing Time={:.2f}s, #Sentences={}'.format(end - start, len(dataset)))
    return dataset, lengths

# Preprocess the dataset
train_dataset, train_data_lengths = preprocess_dataset(train_dataset)
valid_dataset, valid_data_lengths = preprocess_dataset(valid_dataset)
Done! Tokenizing Time=13.87s, #Sentences=198000
Done! Tokenizing Time=2.54s, #Sentences=2000
In [5]:
# create vocab
train_seqs = [sample[0] for sample in train_dataset]
counter = nlp.data.count_tokens(list(itertools.chain.from_iterable(train_seqs)))

vocab = nlp.Vocab(counter, max_size=10000)

# load the pre-trained GloVe embedding
embedding_weights = nlp.embedding.GloVe(source='glove.6B.300d')
vocab.set_embedding(embedding_weights)
print(vocab)

def token_to_idx(x):
    return vocab[x[0]], x[1]

# A token index or a list of token indices is returned according to the vocabulary.
with mp.Pool() as pool:
    train_dataset = pool.map(token_to_idx, train_dataset)
    valid_dataset = pool.map(token_to_idx, valid_dataset)
Vocab(size=10004, unk="<unk>", reserved="['<pad>', '<bos>', '<eos>']")
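
As a quick check, the vocabulary maps tokens to integer indices (out-of-vocabulary tokens map to <unk>) and now carries one 300-dimensional GloVe vector per entry. The tokens below are only illustrative:

# look up the indices of a few tokens; unknown tokens get the <unk> index
print(vocab[['great', 'food', 'gluonnlpxyz']])
# one 300-d vector per vocabulary entry, i.e. shape (10004, 300) here
print(vocab.embedding.idx_to_vec.shape)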

Bucketing and DataLoader

Since each sentence may have a different length, we need to use Pad to pad the sentences in a minibatch to equal lengths so that the data can be quickly turned into tensors on the GPU. At the same time, we need to use Stack to stack the category labels of a batch of samples. For convenience, we use Tuple to combine Pad and Stack, as illustrated in the small example below.
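
Here is a toy illustration of such a combined batchify function on two samples of different lengths (the sample values are made up):

toy_batchify = nlp.data.batchify.Tuple(
    nlp.data.batchify.Pad(axis=0),    # pad the token-index lists to a common length
    nlp.data.batchify.Stack())        # stack the labels into a single array

toy_samples = [([4, 7, 9], 2), ([5, 3], 4)]
toy_data, toy_labels = toy_batchify(toy_samples)
print(toy_data)     # shape (2, 3); the shorter sequence is padded
print(toy_labels)   # shape (2,)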

In order to keep the padding in each minibatch as small as possible, we should group sentences of similar length into the same batch. With this in mind, we construct a sampler using FixedBucketSampler, which defines how the samples in a dataset will be iterated in a more economical way.

Finally, we use DataLoader to build data loaders for the training dataset and the validation dataset. The training dataset requires the FixedBucketSampler, but the validation dataset doesn't require the sampler.

In [6]:
batch_size = 64
bucket_num = 10
bucket_ratio = 0.5


def get_dataloader():
    # construct the batchify function: pad the data and stack the labels
    batchify_fn = nlp.data.batchify.Tuple(
        nlp.data.batchify.Pad(axis=0),
        nlp.data.batchify.Stack())

    # In this example, we use a FixedBucketSampler,
    # which assigns each data sample to a fixed bucket based on its length.
    batch_sampler = nlp.data.sampler.FixedBucketSampler(
        train_data_lengths,
        batch_size=batch_size,
        num_buckets=bucket_num,
        ratio=bucket_ratio,
        shuffle=True)
    print(batch_sampler.stats())

    # train_dataloader
    train_dataloader = gluon.data.DataLoader(
        dataset=train_dataset,
        batch_sampler=batch_sampler,
        batchify_fn=batchify_fn)
    # valid_dataloader
    valid_dataloader = gluon.data.DataLoader(
        dataset=valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        batchify_fn=batchify_fn)
    return train_dataloader, valid_dataloader

train_dataloader, valid_dataloader = get_dataloader()
FixedBucketSampler:
  sample_num=198000, batch_num=2922
  key=[10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
  cnt=[269, 4493, 13820, 14966, 15003, 14056, 12906, 11671, 10268, 100548]
  batch_size=[320, 160, 106, 80, 64, 64, 64, 64, 64, 64]

Model Structure

In the original paper, a sentence is represented as follows: first, the sentence is split into its list of words; then the words are unrolled in order, and the word vector of each word is fed as the input of the corresponding time step of a bidirectional LSTM layer[3]. Collecting the output of every time step of the bidirectional LSTM yields a matrix H. If the hidden size of each LSTM direction is u and the sentence contains n words, then H has shape n-by-2u. For example, the sentence "this movie is amazing" would be represented as in the figure from the original paper (figure: the sentence unrolled through the bidirectional LSTM).

Attention works much like human perception: we always assign different importance to the things within our field of view. For example, when we talk with someone, our eyes pay more attention to the speaker's face than to the edges of the scene. In the same way, when we want to represent a sentence, we can pay different amounts of attention to the rows of the bidirectional LSTM output H (figure: attention weights over the LSTM outputs).

\[A = \mathrm{softmax}\left(W_{s2}\tanh\left(W_{s1}H^T\right)\right)\]

Here, Ws1 is a weight matrix of shape da-by-2u, where da is a hyperparameter, and Ws2 is a weight matrix of shape r-by-da, where r is the number of different attention hops you want to use.

Once the attention matrix A and the LSTM output H have been obtained, the final representation is computed as M = AH.
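
To make the shapes concrete, the following numpy sketch walks through the matrix dimensions for a single sentence (the sizes n, u, da, and r are arbitrary illustrative values):

import numpy as np

n, u, da, r = 4, 300, 350, 2          # sentence length, LSTM hidden size per direction, attention units, hops
H = np.random.randn(n, 2 * u)         # bidirectional LSTM outputs, shape (n, 2u)
W_s1 = np.random.randn(da, 2 * u)     # shape (da, 2u)
W_s2 = np.random.randn(r, da)         # shape (r, da)

scores = W_s2 @ np.tanh(W_s1 @ H.T)   # shape (r, n)
e = np.exp(scores - scores.max(axis=1, keepdims=True))
A = e / e.sum(axis=1, keepdims=True)  # row-wise softmax, shape (r, n)
M = A @ H                             # sentence embedding, shape (r, 2u)
print(A.shape, M.shape)               # (2, 4) (2, 600)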

We first define a custom attention layer, specifying the number of hidden units att_unit and the number of attention hops att_hops.

In [7]:
# custom attention layer
class SelfAttention(nn.HybridBlock):
    def __init__(self, att_unit, att_hops, **kwargs):
        super(SelfAttention, self).__init__(**kwargs)
        with self.name_scope():
            self.ut_dense = nn.Dense(att_unit, activation='tanh', flatten=False)
            self.et_dense = nn.Dense(att_hops, activation=None, flatten=False)

    def hybrid_forward(self, F, x):
        # x shape: [batch_size, seq_len, embedding_width]
        # ut shape: [batch_size, seq_len, att_unit]
        ut = self.ut_dense(x)
        # et shape: [batch_size, seq_len, att_hops]
        et = self.et_dense(ut)

        # att shape: [batch_size,  att_hops, seq_len]
        att = F.softmax(F.transpose(et, axes=(0, 2, 1)), axis=-1)
        # output shape [batch_size, att_hops, embedding_width]
        output = F.batch_dot(att, x)

        return output, att
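
A quick shape check on random data (a minimal sketch, assuming the mxnet imports above) shows that the layer returns both the attended output and the attention weights:

att_layer = SelfAttention(att_unit=100, att_hops=2)
att_layer.initialize()

x = nd.random.uniform(shape=(8, 20, 600))   # [batch_size, seq_len, 2 * lstm_hidden]
out, att = att_layer(x)
print(out.shape, att.shape)                 # (8, 2, 600) (8, 2, 20)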

When the numbers of samples for the different labels are very unbalanced, applying different weights to the labels may improve the performance of the model.

In [8]:
class WeightedSoftmaxCE(nn.HybridBlock):
    def __init__(self, sparse_label=True, from_logits=False,  **kwargs):
        super(WeightedSoftmaxCE, self).__init__(**kwargs)
        with self.name_scope():
            self.sparse_label = sparse_label
            self.from_logits = from_logits

    def hybrid_forward(self, F, pred, label, class_weight, depth=None):
        if self.sparse_label:
            label = F.reshape(label, shape=(-1, ))
            label = F.one_hot(label, depth)
        if not self.from_logits:
            pred = F.log_softmax(pred, -1)

        weight_label = F.broadcast_mul(label, class_weight)
        loss = -F.sum(pred * weight_label, axis=-1)

        # return F.mean(loss, axis=0, exclude=True)
        return loss
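
This loss can be exercised on a dummy batch as follows (a small sketch with made-up logits, labels, and class weights):

wsce = WeightedSoftmaxCE()

dummy_pred = nd.random.uniform(shape=(4, 5))        # logits for 4 samples over 5 classes
dummy_label = nd.array([0, 2, 4, 1])                # sparse class labels
dummy_weight = nd.array([3.0, 5.3, 4.0, 2.0, 1.0])  # one weight per class
print(wsce(dummy_pred, dummy_label, dummy_weight, dummy_weight.shape[0]))
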
In [9]:
class SelfAttentiveBiLSTM(nn.HybridBlock):
    def __init__(self, vocab_len, embsize, nhidden, nlayers, natt_unit, natt_hops, nfc, nclass,
                 drop_prob, pool_way, prune_p=None, prune_q=None, **kwargs):
        super(SelfAttentiveBiLSTM, self).__init__(**kwargs)
        with self.name_scope():
            self.embedding_layer = nn.Embedding(vocab_len, embsize)
            self.bilstm = rnn.LSTM(nhidden, num_layers=nlayers, dropout=drop_prob, bidirectional=True)
            self.att_encoder = SelfAttention(natt_unit, natt_hops)
            self.dense = nn.Dense(nfc, activation='tanh')
            self.output_layer = nn.Dense(nclass)

            self.dense_p, self.dense_q = None, None
            if all([prune_p, prune_q]):
                self.dense_p = nn.Dense(prune_p, activation='tanh', flatten=False)
                self.dense_q = nn.Dense(prune_q, activation='tanh', flatten=False)

            self.drop_prob = drop_prob
            self.pool_way = pool_way

    def hybrid_forward(self, F, inp):
        # input_embed: [batch, len, emsize]
        inp_embed = self.embedding_layer(inp)
        h_output = self.bilstm(F.transpose(inp_embed, axes=(1, 0, 2)))
        # att_output: [batch, att_hops, emsize]
        att_output, att = self.att_encoder(F.transpose(h_output, axes=(1, 0, 2)))

        dense_input = None
        if self.pool_way == 'flatten':
            dense_input = F.Dropout(F.flatten(att_output), self.drop_prob)
        elif self.pool_way == 'mean':
            dense_input = F.Dropout(F.mean(att_output, axis=1), self.drop_prob)
        elif self.pool_way == 'prune' and all([self.dense_p, self.dense_q]):
            # p_section: [batch, att_hops, prune_p]
            p_section = self.dense_p(att_output)
            # q_section: [batch, emsize, prune_q]
            q_section = self.dense_q(F.transpose(att_output, axes=(0, 2, 1)))
            dense_input = F.Dropout(F.concat(F.flatten(p_section), F.flatten(q_section), dim=-1), self.drop_prob)

        dense_out = self.dense(dense_input)
        output = self.output_layer(F.Dropout(dense_out, self.drop_prob))

        return output, att

Configure parameters and build the model

The resulting M is a matrix. To classify from it, we can flatten it, take its mean over the attention hops, or prune it. With flatten, for example, the r-by-2u matrix M is reshaped into a single vector of length 2ru before the fully connected layer. Prune is the parameter-trimming scheme proposed in the original paper and is implemented here as well.

In [10]:
vocab_len = len(vocab)
emsize = 300   # word embedding size
nhidden = 300    # lstm hidden_dim
nlayers = 2     # lstm layers
natt_unit = 300     # the hidden_units of the attention layer
natt_hops = 2    # the channels of attention
nfc = 512
nclass = 5

drop_prob = 0.5
pool_way = 'flatten'    # the way to handle M
prune_p = None
prune_q = None

ctx = try_gpu()

model = SelfAttentiveBiLSTM(vocab_len, emsize, nhidden, nlayers,
                            natt_unit, natt_hops, nfc, nclass,
                            drop_prob, pool_way, prune_p, prune_q)

model.initialize(init=init.Xavier(), ctx=ctx)
model.hybridize()

# attach the pre-trained GloVe word vectors to the embedding layer
model.embedding_layer.weight.set_data(vocab.embedding.idx_to_vec)
# fix the embedding layer so its weights are not updated during training
model.embedding_layer.collect_params().setattr('grad_req', 'null')

Using r attention hops can improve the representation of sentences with different semantics, but if the rows of the attention matrix A (of shape r-by-n) are very close to each other, there is effectively no difference between the attention hops; consequently, in M = AH, the resulting M will contain a lot of redundant information. To solve this problem, we should encourage the rows of A to be clearly different from each other, that is, to satisfy the diversity of attention. A penalty term can be used to achieve this goal.

\[P = ||(AA^T-I)||_F^2\]

It can be seen from the above formula that the more similar the rows of A are to each other, the larger P becomes, and the more dissimilar they are, the smaller P becomes. In other words, the greater the diversity of the r attention hops in A, the smaller P is. So by adding this penalty term to the loss of the model, we can encourage A to be diverse.
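
A tiny sanity check (with made-up attention matrices) illustrates this behaviour: near-identical rows give a larger penalty than rows that attend to different words:

def penalty(A):
    # P = ||A A^T - I||_F^2 for a single attention matrix A of shape (r, n)
    r = A.shape[0]
    return ((nd.dot(A, A.T) - nd.eye(r)) ** 2).sum().asscalar()

A_similar = nd.array([[0.5, 0.3, 0.2], [0.5, 0.3, 0.2]])      # identical rows
A_diverse = nd.array([[0.9, 0.05, 0.05], [0.05, 0.05, 0.9]])  # rows focus on different words
print(penalty(A_similar), penalty(A_diverse))                 # the first value is larger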

In [11]:
def calculate_loss(x, y, model, loss, class_weight, penal_coeff):
    pred, att = model(x)
    # `loss_name` is a global variable defined in the training configuration cell below
    if loss_name == 'sce':
        l = loss(pred, y)
    elif loss_name == 'wsce':
        l = loss(pred, y, class_weight, class_weight.shape[0])

    # penalty
    diversity_penalty = nd.batch_dot(att, nd.transpose(att, axes=(0, 2, 1))
                        ) - nd.eye(att.shape[1], ctx=att.context)
    l = l + penal_coeff * diversity_penalty.norm(axis=(1, 2))

    return pred, l
In [12]:
def one_epoch(data_iter, model, loss, trainer, ctx, is_train, epoch,
              penal_coeff=0.0, clip=None, class_weight=None, loss_name='wsce'):

    loss_val = 0.
    total_pred = []
    total_true = []
    n_batch = 0

    for batch_x, batch_y in data_iter:
        batch_x = batch_x.as_in_context(ctx)
        batch_y = batch_y.as_in_context(ctx)

        if is_train:
            with autograd.record():
                batch_pred, l = calculate_loss(batch_x, batch_y, model, loss, class_weight, penal_coeff)

            # backward pass
            l.backward()

            # clip gradient
            clip_params = [p.data() for p in model.collect_params().values()]
            if clip is not None:
                norm = nd.array([0.0], ctx)
                for param in clip_params:
                    if param.grad is not None:
                        norm += (param.grad ** 2).sum()
                norm = norm.sqrt().asscalar()
                if norm > clip:
                    for param in clip_params:
                        if param.grad is not None:
                            param.grad[:] *= clip / norm

            # update params
            trainer.step(batch_x.shape[0])

        else:
            batch_pred, l = calculate_loss(batch_x, batch_y, model, loss, class_weight, penal_coeff)

        # keep result for metric
        batch_pred = nd.argmax(nd.softmax(batch_pred, axis=1), axis=1).asnumpy()
        batch_true = np.reshape(batch_y.asnumpy(), (-1, ))
        total_pred.extend(batch_pred.tolist())
        total_true.extend(batch_true.tolist())

        batch_loss = l.mean().asscalar()

        n_batch += 1
        loss_val += batch_loss

        # check the result of the training phase
        if is_train and n_batch % 400 == 0:
            print('epoch %d, batch %d, batch_train_loss %.4f, batch_train_acc %.3f' %
                  (epoch, n_batch, batch_loss, accuracy_score(batch_true, batch_pred)))

    # metric
    F1 = f1_score(np.array(total_true), np.array(total_pred), average='weighted')
    acc = accuracy_score(np.array(total_true), np.array(total_pred))
    loss_val /= n_batch

    if is_train:
        print('epoch %d, learning_rate %.5f \n\t train_loss %.4f, acc_train %.3f, F1_train %.3f, ' %
              (epoch, trainer.learning_rate, loss_val, acc, F1))
        # decay the learning rate every two epochs
        if epoch % 2 == 0:
            trainer.set_learning_rate(trainer.learning_rate * 0.9)
    else:
        print('\t valid_loss %.4f, acc_valid %.3f, F1_valid %.3f, ' % (loss_val, acc, F1))
In [13]:
def train_valid(data_iter_train, data_iter_valid, model, loss, trainer, ctx, nepochs,
                penal_coeff=0.0, clip=None, class_weight=None, loss_name='wsce'):

    for epoch in range(1, nepochs+1):
        start = time.time()
        # train
        is_train = True
        one_epoch(data_iter_train, model, loss, trainer, ctx, is_train,
                  epoch, penal_coeff, clip, class_weight, loss_name)

        # valid
        is_train = False
        one_epoch(data_iter_valid, model, loss, trainer, ctx, is_train,
                  epoch, penal_coeff, clip, class_weight, loss_name)
        end = time.time()
        print('time %.2f sec' % (end-start))
        print("*"*100)

Train

Now we train the model. We use WeightedSoftmaxCE to alleviate the problem of class imbalance in the data. The class_weight values used below were obtained by a statistical analysis of the data in advance.
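
For reference, one common way to derive such weights is inverse label frequency; the sketch below shows that approach on the training labels (this is an assumption about how such weights could be computed, not necessarily how the values used below were chosen, and candidate_class_weight is just an illustrative name):

from collections import Counter

label_counts = Counter(label for _, label in train_dataset)    # labels are 0-4 at this point
total = sum(label_counts.values())
inv_freq = nd.array([total / label_counts[i] for i in range(nclass)])
candidate_class_weight = inv_freq / inv_freq.min().asscalar()  # scale so the most frequent class gets weight 1.0
print(candidate_class_weight)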

In [14]:
class_weight = None
loss_name = 'wsce'
optim = 'adam'
lr = 0.001
penal_coeff = 0.1
clip = 0.5
nepochs = 4

trainer = gluon.Trainer(model.collect_params(), optim, {'learning_rate': lr})

if loss_name == 'sce':
    loss = gluon.loss.SoftmaxCrossEntropyLoss()
elif loss_name == 'wsce':
    loss = WeightedSoftmaxCE()
    # the value of class_weight is obtained by counting data in advance. It can be seen as a hyperparameter.
    class_weight = nd.array([3.0, 5.3, 4.0, 2.0, 1.0], ctx=ctx)
In [15]:
# train and valid
train_valid(train_dataloader, valid_dataloader, model, loss, trainer, ctx, nepochs,
            penal_coeff=penal_coeff, clip=clip, class_weight=class_weight, loss_name=loss_name)
epoch 1, batch 400, batch_train_loss 2.2339, batch_train_acc 0.609
epoch 1, batch 800, batch_train_loss 3.7565, batch_train_acc 0.328
epoch 1, batch 1200, batch_train_loss 3.0790, batch_train_acc 0.453
epoch 1, batch 1600, batch_train_loss 2.0622, batch_train_acc 0.625
epoch 1, batch 2000, batch_train_loss 1.8870, batch_train_acc 0.625
epoch 1, batch 2400, batch_train_loss 1.7517, batch_train_acc 0.672
epoch 1, batch 2800, batch_train_loss 2.7688, batch_train_acc 0.562
epoch 1, learning_rate 0.00100
         train_loss 2.6618, acc_train 0.554, F1_train 0.566,
         valid_loss 2.5027, acc_valid 0.622, F1_valid 0.604,
time 230.11 sec
****************************************************************************************************
epoch 2, batch 400, batch_train_loss 1.7688, batch_train_acc 0.641
epoch 2, batch 800, batch_train_loss 2.6971, batch_train_acc 0.562
epoch 2, batch 1200, batch_train_loss 2.7339, batch_train_acc 0.609
epoch 2, batch 1600, batch_train_loss 1.6485, batch_train_acc 0.800
epoch 2, batch 2000, batch_train_loss 3.0782, batch_train_acc 0.516
epoch 2, batch 2400, batch_train_loss 2.1049, batch_train_acc 0.713
epoch 2, batch 2800, batch_train_loss 1.8566, batch_train_acc 0.672
epoch 2, learning_rate 0.00100
         train_loss 2.3834, acc_train 0.604, F1_train 0.615,
         valid_loss 2.3613, acc_valid 0.649, F1_valid 0.639,
time 226.31 sec
****************************************************************************************************
epoch 3, batch 400, batch_train_loss 2.7342, batch_train_acc 0.531
epoch 3, batch 800, batch_train_loss 2.2452, batch_train_acc 0.609
epoch 3, batch 1200, batch_train_loss 2.8032, batch_train_acc 0.547
epoch 3, batch 1600, batch_train_loss 2.4077, batch_train_acc 0.641
epoch 3, batch 2000, batch_train_loss 1.8940, batch_train_acc 0.719
epoch 3, batch 2400, batch_train_loss 3.1043, batch_train_acc 0.484
epoch 3, batch 2800, batch_train_loss 1.5103, batch_train_acc 0.625
epoch 3, learning_rate 0.00090
         train_loss 2.2506, acc_train 0.627, F1_train 0.637,
         valid_loss 2.2129, acc_valid 0.658, F1_valid 0.653,
time 185.20 sec
****************************************************************************************************
epoch 4, batch 400, batch_train_loss 2.7873, batch_train_acc 0.562
epoch 4, batch 800, batch_train_loss 1.6864, batch_train_acc 0.781
epoch 4, batch 1200, batch_train_loss 1.7567, batch_train_acc 0.672
epoch 4, batch 1600, batch_train_loss 2.5535, batch_train_acc 0.656
epoch 4, batch 2000, batch_train_loss 1.6658, batch_train_acc 0.613
epoch 4, batch 2400, batch_train_loss 2.0035, batch_train_acc 0.609
epoch 4, batch 2800, batch_train_loss 2.0003, batch_train_acc 0.719
epoch 4, learning_rate 0.00090
         train_loss 2.1562, acc_train 0.644, F1_train 0.654,
         valid_loss 2.2839, acc_valid 0.655, F1_valid 0.653,
time 203.42 sec
****************************************************************************************************

Predict

Now we feed an arbitrary sentence into the model and predict its sentiment label. The labels range from 1 to 5, corresponding to degrees of sentiment from negative to positive.

In [16]:
input_ar = nd.array(vocab[['This', 'movie', 'is', 'amazing']], ctx=ctx).reshape((1, -1))
pred, att = model(input_ar)

label = np.argmax(nd.softmax(pred, axis=1).asnumpy(), axis=1) + 1
print(label)
print(att)
[5]

[[[0.17184225 0.2434033  0.18269765 0.40205672]
  [0.18523948 0.26206964 0.25372118 0.29896972]]]
<NDArray 1x2x4 @gpu(0)>

To get an intuitive feel for the role of the attention mechanism, we visualize the attention weights that the model produced for the predicted sample.

In [17]:
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

np.squeeze(att.asnumpy(), 0).shape
plt.figure(figsize=(8,1))
cmap = sns.diverging_palette(220, 10, as_cmap=True)
sns.heatmap(np.squeeze(att.asnumpy(), 0), cmap=cmap, annot=True,
            xticklabels=['This', 'movie', 'is', 'amazing'], yticklabels=['att0', 'att1'])
plt.show()
(output: a heatmap of the two attention hops over the tokens "This", "movie", "is", "amazing")

Save Model

In [18]:
model_dir = '../models'
if not os.path.exists(model_dir):
    os.makedirs(model_dir)
model_path = os.path.join(model_dir, 'self_att_bilstm_model')
model.export(model_path)
print('the structure and params of the model have been saved in:', model_path)

the structure and params of the model have been saved in: ../models/self_att_bilstm_model
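
The exported files can later be loaded back for inference without the Python class definition, for example with SymbolBlock (a sketch; the '-symbol.json' and '-0000.params' suffixes are the file names model.export writes by default):

# load the exported symbol and parameters back for inference
loaded_model = gluon.nn.SymbolBlock.imports(
    model_path + '-symbol.json', ['data'], model_path + '-0000.params', ctx=ctx)

test_input = nd.array(vocab[['This', 'movie', 'is', 'amazing']], ctx=ctx).reshape((1, -1))
pred, att = loaded_model(test_input)
print(pred.shape, att.shape)   # expected: (1, 5) (1, 2, 4)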

Conclusion

Word embeddings can effectively represent the semantic similarity between words, which has brought breakthroughs to many complex natural language processing tasks. The attention mechanism can intuitively pick out the important semantic features in a sentence, and the LSTM captures the word-order relationships between the words. Together, word embeddings, the LSTM, and the attention mechanism effectively represent the semantics of a sentence, and this representation can be applied to many practical tasks.

GluonNLP provides us with many efficient and convenient tools that help us experiment quickly. This greatly simplifies the tedious work of natural language processing.