# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=undefined-all-variable
"""Batchify transforms for language-model datasets.

Includes transforms that numericalize a tokenized corpus (or a stream of
corpora) and arrange it into batches for truncated
back-propagation-through-time (TBPTT).
"""

# Public names exported by `from <module> import *`.
# NOTE(review): 'CorpusBatchify' is not defined in the portion of the file
# under review; the pylint directive at the top of this block presumably
# covers that -- confirm it is defined elsewhere in this module.
__all__ = ['CorpusBatchify', 'CorpusBPTTBatchify', 'StreamBPTTBatchify']

import itertools
import math

import numpy as np
import mxnet as mx
from mxnet.gluon.data import RandomSampler, SequentialSampler, SimpleDataset

class CorpusBPTTBatchify:
    """Transform the dataset into batches of numericalized samples, in the way
    that the recurrent states from last batch connects with the current batch
    for each sample.

    Each sample is of shape `(seq_len, batch_size)`. When `last_batch='keep'`,
    the first dimension of last sample may be shorter than `seq_len`.

    Parameters
    ----------
    vocab : gluonnlp.Vocab
        The vocabulary to use for numericalizing the dataset. Each token will
        be mapped to the index according to the vocabulary.
    seq_len : int
        The length of each of the samples for truncated
        back-propagation-through-time (TBPTT).
    batch_size : int
        The number of samples in each batch.
    last_batch : {'keep', 'discard'}
        How to handle the last batch if the remaining length is less than
        `seq_len`.

        - keep: A batch with fewer samples than previous batches is returned.
          `vocab.padding_token` is used to pad the last batch based on batch
          size.
        - discard: The last batch is discarded if it's smaller than
          `(seq_len, batch_size)`.

    Raises
    ------
    ValueError
        If `last_batch` is not 'keep' or 'discard', or if `last_batch='keep'`
        and `vocab.padding_token` is not set (falsy).
    """

    def __init__(self, vocab, seq_len, batch_size, last_batch='keep'):
        self._vocab = vocab
        self._seq_len = seq_len
        self._batch_size = batch_size
        self._last_batch = last_batch

        if last_batch not in ['keep', 'discard']:
            raise ValueError(
                'Got invalid last_batch: "{}". Must be "keep" or "discard".'.
                format(last_batch))

        # Only 'keep' pads the trailing batch, so a padding token is required
        # just in that mode; 'discard' never touches vocab.padding_token.
        if last_batch == 'keep' and not self._vocab.padding_token:
            raise ValueError('vocab.padding_token must be specified '
                             'in vocab when last_batch="keep".')
class StreamBPTTBatchify:
    """Transform a Stream of CorpusDataset to BPTT batches.

    The corpus is transformed into batches of numericalized samples, in the
    way that the recurrent states from last batch connects with the current
    batch for each sample.

    Each sample is of shape `(seq_len, batch_size)`.

    For example, the following 4 sequences::

        a b c d <eos>
        e f g h i j <eos>
        k l m n <eos>
        o <eos>

    will generate 2 batches with seq_len = 5, batch_size = 2 as follows
    (transposed):

    batch_0.data.T::

        a b c d <eos>
        e f g h i

    batch_0.target.T::

        b c d <eos> k
        f g h i j

    batch_1.data.T::

        k l m n <eos>
        j <eos> o <eos> <padding>

    batch_1.target.T::

        l m n <eos> <padding>
        <eos> o <eos> <padding> <padding>

    Parameters
    ----------
    vocab : gluonnlp.Vocab
        The vocabulary to use for numericalizing the dataset. Each token will
        be mapped to the index according to the vocabulary.
    seq_len : int
        The length of each of the samples for truncated
        back-propagation-through-time (TBPTT).
    batch_size : int
        The number of samples in each batch.
    sampler : str, {'sequential', 'random'}, defaults to 'random'
        The sampler used to sample texts within a file.

        - 'sequential': SequentialSampler
        - 'random': RandomSampler
    last_batch : {'keep', 'discard'}
        How to handle the last batch if the remaining length is less than
        `seq_len`.

        - keep: A batch with fewer samples than previous batches is returned.
        - discard: The last batch is discarded if it's smaller than
          `(seq_len, batch_size)`.

    Raises
    ------
    ValueError
        If `vocab.padding_token` is not set (falsy), if `last_batch` is not
        'keep' or 'discard', or if `sampler` is a str other than 'random' or
        'sequential'.
    TypeError
        If `sampler` is not a str.
    """

    def __init__(self, vocab, seq_len, batch_size, sampler='random', last_batch='keep'):
        self._vocab = vocab
        self._seq_len = seq_len
        self._batch_size = batch_size
        self._sampler = sampler
        self._last_batch = last_batch

        # The padding token is checked unconditionally, for every last_batch mode.
        if not self._vocab.padding_token:
            raise ValueError('Padding token must be specified in vocab for StreamBPTTBatchify.')

        if last_batch not in ['keep', 'discard']:
            raise ValueError(
                'Got invalid last_batch: "{}". Must be "keep" or "discard".'.
                format(last_batch))

        # Fail fast on a bad sampler name, just like last_batch above, instead
        # of deferring the error until the sampler class is first looked up.
        self._check_sampler_name(sampler)

    @staticmethod
    def _check_sampler_name(sampler):
        """Raise if `sampler` is not a supported sampler name.

        Explicit raises are used rather than ``assert`` so that the check is
        not stripped when Python runs with ``-O``.

        Raises
        ------
        TypeError
            If `sampler` is not a str.
        ValueError
            If `sampler` is a str other than 'random' or 'sequential'.
        """
        if not isinstance(sampler, str):
            raise TypeError('Expected sampler to be a str, but got %s' % type(sampler))
        if sampler not in ('random', 'sequential'):
            # 1-tuple so the %-format is safe for any str value.
            raise ValueError(
                'sampler must be either "random" or "sequential", but got %s' % (sampler,))

    def _get_sampler(self, sampler):
        """Return the sampler class named by `sampler`.

        Parameters
        ----------
        sampler : str
            Either 'random' or 'sequential'.

        Returns
        -------
        type
            `RandomSampler` for 'random', `SequentialSampler` for 'sequential'
            (both imported from `mxnet.gluon.data`).

        Raises
        ------
        TypeError
            If `sampler` is not a str.
        ValueError
            If `sampler` is any other str.
        """
        self._check_sampler_name(sampler)
        return RandomSampler if sampler == 'random' else SequentialSampler