# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Sentiment analysis datasets."""

__all__ = ['IMDB']

import io
import json
import os

from mxnet.gluon.data import SimpleDataset
from mxnet.gluon.utils import download, check_sha1, _get_repo_file_url
from .registry import register
from .utils import _get_home_dir

@register(segment=['train', 'test', 'unsup'])
class IMDB(SimpleDataset):
    """IMDB reviews for sentiment analysis.

    From http://ai.stanford.edu/~amaas/data/sentiment/

    The dataset content is the object returned by ``json.load`` on the
    segment's data file (downloaded on first use and verified by SHA-1).

    Parameters
    ----------
    segment : str, default 'train'
        Dataset segment. Options are 'train', 'test', and 'unsup' for unsupervised.
    root : str, default '$MXNET_HOME/datasets/imdb'
        Path to temp folder for storing data.
        MXNET_HOME defaults to '~/.mxnet'.
    """
    def __init__(self, segment='train',
                 root=os.path.join(_get_home_dir(), 'datasets', 'imdb')):
        # Maps each segment name to (file name, expected SHA-1 of that file).
        # Both downloading and reading use this table so the file names
        # are defined in exactly one place.
        self._data_file = {'train': ('train.json',
                                     '516a0ba06bca4e32ee11da2e129f4f871dff85dc'),
                           'test': ('test.json',
                                    '7d59bd8899841afdc1c75242815260467495b64a'),
                           'unsup': ('unsup.json',
                                     'f908a632b7e7d7ecf113f74c968ef03fadfc3c6c')}
        root = os.path.expanduser(root)
        if not os.path.isdir(root):
            os.makedirs(root)
        self._root = root
        self._segment = segment
        self._get_data()
        super(IMDB, self).__init__(self._read_data())

    def _get_data(self):
        """Ensure the segment's data file exists in ``self._root``.

        The file is (re-)downloaded when it is missing or when its SHA-1
        does not match the expected hash.

        Raises
        ------
        KeyError
            If ``self._segment`` is not one of the known segments.
        """
        data_file_name, data_hash = self._data_file[self._segment]
        root = self._root
        path = os.path.join(root, data_file_name)
        if not os.path.exists(path) or not check_sha1(path, data_hash):
            download(_get_repo_file_url('gluon/dataset/imdb', data_file_name),
                     path=root, sha1_hash=data_hash)

    def _read_data(self):
        """Parse and return the JSON content of the segment's data file.

        Raises
        ------
        KeyError
            If ``self._segment`` is not one of the known segments.
        """
        # Use the same file-name table as _get_data so the file that was
        # downloaded/verified is exactly the file that gets read.
        data_file_name, _ = self._data_file[self._segment]
        path = os.path.join(self._root, data_file_name)
        # Decode explicitly as UTF-8 (the JSON interchange encoding, and a
        # superset of ASCII-escaped JSON) so loading does not depend on the
        # platform's default locale encoding. io.open accepts `encoding`
        # on both Python 2 and 3.
        with io.open(path, encoding='utf-8') as f:
            samples = json.load(f)
        return samples