
# Copyright 2018 The Google AI Language Team Authors and DMLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GLUE classification/regression datasets."""


__all__ = [
    'MRPCTask', 'QQPTask', 'QNLITask', 'RTETask', 'STSBTask',
    'CoLATask', 'MNLITask', 'WNLITask', 'SSTTask', 'XNLITask', 'get_task'
]

from copy import copy
from mxnet.metric import Accuracy, F1, MCC, PearsonCorrelation, CompositeEvalMetric
from .glue import GlueCoLA, GlueSST2, GlueSTSB, GlueMRPC
from .glue import GlueQQP, GlueRTE, GlueMNLI, GlueQNLI, GlueWNLI
from .baidu_ernie_data import BaiduErnieXNLI, BaiduErnieChnSentiCorp, BaiduErnieLCQMC


class GlueTask:
    """Abstract GLUE task class.

    Parameters
    ----------
    class_labels : list of str, or None
        Classification labels of the task.
        Set to None for regression tasks with continuous real values.
    metrics : EvalMetric or CompositeEvalMetric
        Evaluation metrics of the task.
    is_pair : bool
        Whether the task deals with sentence pairs or single sentences.
    label_alias : dict, or None
        Dictionary mapping alternative label names to canonical ones. Some
        datasets spell the same label in different ways, e.g.
        {'contradictory': 'contradiction'} assigns both 'contradictory' and
        'contradiction' the same class id.
    """
    def __init__(self, class_labels, metrics, is_pair, label_alias=None):
        self.class_labels = class_labels
        self.metrics = metrics
        self.is_pair = is_pair
        self.label_alias = label_alias

    def get_dataset(self, segment='train'):
        """Get the corresponding dataset for the task.

        Parameters
        ----------
        segment : str, default 'train'
            Dataset segments.

        Returns
        -------
        TSVDataset : the dataset of the target segment.
        """
        raise NotImplementedError()

    def dataset_train(self):
        """Get the training segment of the dataset for the task.

        Returns
        -------
        tuple of str, TSVDataset : the segment name, and the dataset.
        """
        return 'train', self.get_dataset(segment='train')

    def dataset_dev(self):
        """Get the dev segment of the dataset for the task.

        Returns
        -------
        tuple of (str, TSVDataset), or list of tuple : the segment name, and the dataset.
        """
        return 'dev', self.get_dataset(segment='dev')

    def dataset_test(self):
        """Get the test segment of the dataset for the task.

        Returns
        -------
        tuple of (str, TSVDataset), or list of tuple : the segment name, and the dataset.
        """
        return 'test', self.get_dataset(segment='test')

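
# A minimal sketch (not part of gluonnlp) showing how a concrete task is built
# on top of GlueTask: a subclass picks the class labels and metric in
# ``__init__`` and overrides ``get_dataset``. ``_ToyTask`` and its inline data
# are hypothetical and exist purely for illustration.
class _ToyTask(GlueTask):
    """A toy single-sentence binary classification task."""
    def __init__(self):
        super(_ToyTask, self).__init__(
            class_labels=['0', '1'], metrics=Accuracy(), is_pair=False)

    def get_dataset(self, segment='train'):
        # A real task would load the requested segment from disk here.
        from mxnet.gluon.data import SimpleDataset
        return SimpleDataset([['a plausible sentence', '1'],
                              ['sentence a implausible', '0']])

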
class MRPCTask(GlueTask):
    """The MRPC task on GlueBenchmark.

    Examples
    --------
    >>> MRPC = MRPCTask()
    >>> MRPC.class_labels
    ['0', '1']
    >>> type(MRPC.metrics.get_metric(0))
    <class 'mxnet.metric.Accuracy'>
    >>> type(MRPC.metrics.get_metric(1))
    <class 'mxnet.metric.F1'>
    >>> MRPC.dataset_train()[0]
    -etc-
    'train'
    >>> len(MRPC.dataset_train()[1])
    3668
    >>> MRPC.dataset_dev()[0]
    'dev'
    >>> len(MRPC.dataset_dev()[1])
    408
    >>> MRPC.dataset_test()[0]
    -etc-
    'test'
    >>> len(MRPC.dataset_test()[1])
    1725
    """
    def __init__(self):
        is_pair = True
        class_labels = ['0', '1']
        metric = CompositeEvalMetric()
        metric.add(Accuracy())
        metric.add(F1(average='micro'))
        super(MRPCTask, self).__init__(class_labels, metric, is_pair)

    def get_dataset(self, segment='train'):
        """Get the corresponding dataset for MRPC.

        Parameters
        ----------
        segment : str, default 'train'
            Dataset segments. Options are 'train', 'dev', 'test'.
        """
        return GlueMRPC(segment=segment)


class QQPTask(GlueTask):
    """The Quora Question Pairs task on GlueBenchmark.

    Examples
    --------
    >>> QQP = QQPTask()
    >>> QQP.class_labels
    ['0', '1']
    >>> type(QQP.metrics.get_metric(0))
    <class 'mxnet.metric.Accuracy'>
    >>> type(QQP.metrics.get_metric(1))
    <class 'mxnet.metric.F1'>
    >>> import warnings
    >>> with warnings.catch_warnings():
    ...     # Ignore warnings triggered by invalid entries in GlueQQP set
    ...     warnings.simplefilter("ignore")
    ...     QQP.dataset_train()[0]
    -etc-
    'train'
    >>> QQP.dataset_test()[0]
    -etc-
    'test'
    >>> len(QQP.dataset_test()[1])
    390965
    """
    def __init__(self):
        is_pair = True
        class_labels = ['0', '1']
        metric = CompositeEvalMetric()
        metric.add(Accuracy())
        metric.add(F1(average='micro'))
        super(QQPTask, self).__init__(class_labels, metric, is_pair)

    def get_dataset(self, segment='train'):
        """Get the corresponding dataset for QQP.

        Parameters
        ----------
        segment : str, default 'train'
            Dataset segments. Options are 'train', 'dev', 'test'.
        """
        return GlueQQP(segment=segment)


class RTETask(GlueTask):
    """The Recognizing Textual Entailment task on GlueBenchmark.

    Examples
    --------
    >>> RTE = RTETask()
    >>> RTE.class_labels
    ['not_entailment', 'entailment']
    >>> type(RTE.metrics)
    <class 'mxnet.metric.Accuracy'>
    >>> RTE.dataset_train()[0]
    -etc-
    'train'
    >>> len(RTE.dataset_train()[1])
    2490
    >>> RTE.dataset_dev()[0]
    -etc-
    'dev'
    >>> len(RTE.dataset_dev()[1])
    277
    >>> RTE.dataset_test()[0]
    -etc-
    'test'
    >>> len(RTE.dataset_test()[1])
    3000
    """
    def __init__(self):
        is_pair = True
        class_labels = ['not_entailment', 'entailment']
        metric = Accuracy()
        super(RTETask, self).__init__(class_labels, metric, is_pair)

    def get_dataset(self, segment='train'):
        """Get the corresponding dataset for RTE.

        Parameters
        ----------
        segment : str, default 'train'
            Dataset segments. Options are 'train', 'dev', 'test'.
        """
        return GlueRTE(segment=segment)


class QNLITask(GlueTask):
    """The SQuAD NLI task on GlueBenchmark.

    Examples
    --------
    >>> QNLI = QNLITask()
    >>> QNLI.class_labels
    ['not_entailment', 'entailment']
    >>> type(QNLI.metrics)
    <class 'mxnet.metric.Accuracy'>
    >>> QNLI.dataset_train()[0]
    -etc-
    'train'
    >>> len(QNLI.dataset_train()[1])
    108436
    >>> QNLI.dataset_dev()[0]
    -etc-
    'dev'
    >>> len(QNLI.dataset_dev()[1])
    5732
    >>> QNLI.dataset_test()[0]
    -etc-
    'test'
    >>> len(QNLI.dataset_test()[1])
    5740
    """
    def __init__(self):
        is_pair = True
        class_labels = ['not_entailment', 'entailment']
        metric = Accuracy()
        super(QNLITask, self).__init__(class_labels, metric, is_pair)

    def get_dataset(self, segment='train'):
        """Get the corresponding dataset for QNLI.

        Parameters
        ----------
        segment : str, default 'train'
            Dataset segments. Options are 'train', 'dev', 'test'.
        """
        return GlueQNLI(segment=segment)


class STSBTask(GlueTask):
    """The Semantic Textual Similarity Benchmark task on GlueBenchmark.

    Examples
    --------
    >>> STSB = STSBTask()
    >>> STSB.class_labels
    >>> type(STSB.metrics)
    <class 'mxnet.metric.PearsonCorrelation'>
    >>> STSB.dataset_train()[0]
    -etc-
    'train'
    >>> len(STSB.dataset_train()[1])
    5749
    >>> STSB.dataset_dev()[0]
    -etc-
    'dev'
    >>> len(STSB.dataset_dev()[1])
    1500
    >>> STSB.dataset_test()[0]
    -etc-
    'test'
    >>> len(STSB.dataset_test()[1])
    1379
    """
    def __init__(self):
        is_pair = True
        class_labels = None
        metric = PearsonCorrelation(average='micro')
        super(STSBTask, self).__init__(class_labels, metric, is_pair)

    def get_dataset(self, segment='train'):
        """Get the corresponding dataset for STS-B.

        Parameters
        ----------
        segment : str, default 'train'
            Dataset segments. Options are 'train', 'dev', 'test'.
        """
        return GlueSTSB(segment=segment)


class CoLATask(GlueTask):
    """The Warstadt acceptability task on GlueBenchmark.

    Examples
    --------
    >>> CoLA = CoLATask()
    >>> CoLA.class_labels
    ['0', '1']
    >>> type(CoLA.metrics)
    <class 'mxnet.metric.MCC'>
    >>> CoLA.dataset_train()[0]
    -etc-
    'train'
    >>> len(CoLA.dataset_train()[1])
    8551
    >>> CoLA.dataset_dev()[0]
    -etc-
    'dev'
    >>> len(CoLA.dataset_dev()[1])
    1043
    >>> CoLA.dataset_test()[0]
    -etc-
    'test'
    >>> len(CoLA.dataset_test()[1])
    1063
    """
    def __init__(self):
        is_pair = False
        class_labels = ['0', '1']
        metric = MCC(average='micro')
        super(CoLATask, self).__init__(class_labels, metric, is_pair)

    def get_dataset(self, segment='train'):
        """Get the corresponding dataset for CoLA.

        Parameters
        ----------
        segment : str, default 'train'
            Dataset segments. Options are 'train', 'dev', 'test'.
        """
        return GlueCoLA(segment=segment)


class SSTTask(GlueTask):
    """The Stanford Sentiment Treebank task on GlueBenchmark.

    Examples
    --------
    >>> SST = SSTTask()
    >>> SST.class_labels
    ['0', '1']
    >>> type(SST.metrics)
    <class 'mxnet.metric.Accuracy'>
    >>> SST.dataset_train()[0]
    -etc-
    'train'
    >>> len(SST.dataset_train()[1])
    67349
    >>> SST.dataset_dev()[0]
    -etc-
    'dev'
    >>> len(SST.dataset_dev()[1])
    872
    >>> SST.dataset_test()[0]
    -etc-
    'test'
    >>> len(SST.dataset_test()[1])
    1821
    """
    def __init__(self):
        is_pair = False
        class_labels = ['0', '1']
        metric = Accuracy()
        super(SSTTask, self).__init__(class_labels, metric, is_pair)

    def get_dataset(self, segment='train'):
        """Get the corresponding dataset for SST.

        Parameters
        ----------
        segment : str, default 'train'
            Dataset segments. Options are 'train', 'dev', 'test'.
        """
        return GlueSST2(segment=segment)


class WNLITask(GlueTask):
    """The Winograd NLI task on GlueBenchmark.

    Examples
    --------
    >>> WNLI = WNLITask()
    >>> WNLI.class_labels
    ['0', '1']
    >>> type(WNLI.metrics)
    <class 'mxnet.metric.Accuracy'>
    >>> WNLI.dataset_train()[0]
    -etc-
    'train'
    >>> len(WNLI.dataset_train()[1])
    635
    >>> WNLI.dataset_dev()[0]
    -etc-
    'dev'
    >>> len(WNLI.dataset_dev()[1])
    71
    >>> WNLI.dataset_test()[0]
    -etc-
    'test'
    >>> len(WNLI.dataset_test()[1])
    146
    """
    def __init__(self):
        is_pair = True
        class_labels = ['0', '1']
        metric = Accuracy()
        super(WNLITask, self).__init__(class_labels, metric, is_pair)

    def get_dataset(self, segment='train'):
        """Get the corresponding dataset for WNLI.

        Parameters
        ----------
        segment : str, default 'train'
            Dataset segments. Options are 'dev', 'test', 'train'.
        """
        return GlueWNLI(segment=segment)


class MNLITask(GlueTask):
    """The Multi-Genre Natural Language Inference task on GlueBenchmark.

    Examples
    --------
    >>> MNLI = MNLITask()
    >>> MNLI.class_labels
    ['neutral', 'entailment', 'contradiction']
    >>> type(MNLI.metrics)
    <class 'mxnet.metric.Accuracy'>
    >>> MNLI.dataset_train()[0]
    -etc-
    'train'
    >>> len(MNLI.dataset_train()[1])
    392702
    >>> MNLI.dataset_dev()[0][0]
    -etc-
    'dev_matched'
    >>> len(MNLI.dataset_dev()[0][1])
    9815
    >>> MNLI.dataset_dev()[1][0]
    'dev_mismatched'
    >>> len(MNLI.dataset_dev()[1][1])
    9832
    >>> MNLI.dataset_test()[0][0]
    -etc-
    'test_matched'
    >>> len(MNLI.dataset_test()[0][1])
    9796
    >>> MNLI.dataset_test()[1][0]
    'test_mismatched'
    >>> len(MNLI.dataset_test()[1][1])
    9847
    """
    def __init__(self):
        is_pair = True
        class_labels = ['neutral', 'entailment', 'contradiction']
        metric = Accuracy()
        super(MNLITask, self).__init__(class_labels, metric, is_pair)

    def get_dataset(self, segment='train'):
        """Get the corresponding dataset for MNLI.

        Parameters
        ----------
        segment : str, default 'train'
            Dataset segments. Options are 'dev_matched', 'dev_mismatched',
            'test_matched', 'test_mismatched', 'train'.
        """
        return GlueMNLI(segment=segment)

    def dataset_dev(self):
        """Get the dev segment of the dataset for the task.

        Returns
        -------
        list of tuple of (str, TSVDataset) : the segment names and datasets
            of the matched and mismatched dev sets.
        """
        return [('dev_matched', self.get_dataset(segment='dev_matched')),
                ('dev_mismatched', self.get_dataset(segment='dev_mismatched'))]

    def dataset_test(self):
        """Get the test segment of the dataset for the task.

        Returns
        -------
        list of tuple of (str, TSVDataset) : the segment names and datasets
            of the matched and mismatched test sets.
        """
        return [('test_matched', self.get_dataset(segment='test_matched')),
                ('test_mismatched', self.get_dataset(segment='test_mismatched'))]


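# Because MNLI ships matched and mismatched evaluation sets, ``dataset_dev``
# and ``dataset_test`` above return lists of (segment, dataset) pairs instead
# of a single pair. A hedged sketch of how a caller might consume them:
#
#     for segment, data in MNLITask().dataset_dev():
#         print(segment, len(data))

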
class XNLITask(GlueTask):
    """The XNLI task using the dataset released from Baidu
    <https://github.com/PaddlePaddle/LARK/tree/develop/ERNIE>.

    Examples
    --------
    >>> XNLI = XNLITask()
    >>> XNLI.class_labels
    ['neutral', 'entailment', 'contradiction']
    >>> type(XNLI.metrics)
    <class 'mxnet.metric.Accuracy'>
    >>> XNLI.dataset_train()[0]
    'train'
    >>> len(XNLI.dataset_train()[1])
    392702
    >>> XNLI.dataset_dev()[0]
    'dev'
    >>> len(XNLI.dataset_dev()[1])
    2490
    >>> XNLI.dataset_test()[0]
    'test'
    >>> len(XNLI.dataset_test()[1])
    5010
    """
    def __init__(self):
        is_pair = True
        class_labels = ['neutral', 'entailment', 'contradiction']
        metric = Accuracy()
        super(XNLITask, self).__init__(class_labels, metric, is_pair,
                                       label_alias={'contradictory': 'contradiction'})

    def get_dataset(self, segment='train'):
        """Get the corresponding dataset for XNLI.

        Parameters
        ----------
        segment : str, default 'train'
            Dataset segments. Options are 'dev', 'test', 'train'.
        """
        return BaiduErnieXNLI(segment)


class LCQMCTask(GlueTask):
    """The LCQMC task.

    Note that this dataset is no longer public. You can apply to the
    dataset owners for LCQMC. http://icrc.hitsz.edu.cn/info/1037/1146.htm
    """
    def __init__(self):
        is_pair = True
        class_labels = ['0', '1']
        metric = Accuracy()
        super(LCQMCTask, self).__init__(class_labels, metric, is_pair)

    def get_dataset(self, file_path, segment='train'):  # pylint: disable=arguments-differ
        """Get the corresponding dataset for LCQMC.

        Parameters
        ----------
        file_path : str
            Path to the dataset file.
        segment : str, default 'train'
            Dataset segments. Options are 'dev', 'test', 'train'.
        """
        return BaiduErnieLCQMC(file_path, segment)


class ChnSentiCorpTask(GlueTask):
    """The ChnSentiCorp task using the dataset released from Baidu
    <https://github.com/PaddlePaddle/LARK/tree/develop/ERNIE>.

    Examples
    --------
    >>> ChnSentiCorp = ChnSentiCorpTask()
    >>> ChnSentiCorp.class_labels
    ['0', '1']
    >>> type(ChnSentiCorp.metrics)
    <class 'mxnet.metric.Accuracy'>
    >>> ChnSentiCorp.dataset_train()[0]
    'train'
    >>> len(ChnSentiCorp.dataset_train()[1])
    9600
    >>> ChnSentiCorp.dataset_dev()[0]
    'dev'
    >>> len(ChnSentiCorp.dataset_dev()[1])
    1200
    >>> ChnSentiCorp.dataset_test()[0]
    'test'
    >>> len(ChnSentiCorp.dataset_test()[1])
    1200
    """
    def __init__(self):
        is_pair = False
        class_labels = ['0', '1']
        metric = Accuracy()
        super(ChnSentiCorpTask, self).__init__(class_labels, metric, is_pair)

    def get_dataset(self, segment='train'):
        """Get the corresponding dataset for ChnSentiCorp.

        Parameters
        ----------
        segment : str, default 'train'
            Dataset segments. Options are 'dev', 'test', 'train'.
        """
        return BaiduErnieChnSentiCorp(segment)


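# A hedged usage sketch for the two Baidu ERNIE tasks above. LCQMC is obtained
# from its owners rather than downloaded automatically, so its ``get_dataset``
# takes an explicit path; '/path/to/lcqmc/train.tsv' below is a placeholder
# for a locally obtained copy.
#
#     train_data = LCQMCTask().get_dataset('/path/to/lcqmc/train.tsv',
#                                          segment='train')
#     dev_data = ChnSentiCorpTask().get_dataset(segment='dev')

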
def get_task(task):
    """Returns a pre-defined GLUE task by name.

    Parameters
    ----------
    task : str
        Options include 'MRPC', 'QQP', 'QNLI', 'RTE', 'STS-B', 'CoLA',
        'MNLI', 'WNLI', 'SST', 'XNLI', 'LCQMC', 'ChnSentiCorp'.

    Returns
    -------
    GlueTask
    """
    tasks = {
        'mrpc': MRPCTask(),
        'qqp': QQPTask(),
        'qnli': QNLITask(),
        'rte': RTETask(),
        'sts-b': STSBTask(),
        'cola': CoLATask(),
        'mnli': MNLITask(),
        'wnli': WNLITask(),
        'sst': SSTTask(),
        'xnli': XNLITask(),
        'lcqmc': LCQMCTask(),
        'chnsenticorp': ChnSentiCorpTask()
    }
    if task.lower() not in tasks:
        raise ValueError(
            'Task name %s is not supported. Available options are\n\t%s' % (
                task, '\n\t'.join(sorted(tasks.keys()))))
    return copy(tasks[task.lower()])
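

# Example usage, as a minimal sketch: the expected values below come from the
# doctests above, and the first call downloads the MRPC data.
if __name__ == '__main__':
    mrpc = get_task('MRPC')
    print(mrpc.class_labels)            # ['0', '1']
    segment, train_data = mrpc.dataset_train()
    print(segment, len(train_data))     # train 3668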