Source code for chainer.datasets.ptb

import os

import numpy

from chainer.dataset import download


def get_ptb_words():
    """Gets the Penn Tree Bank dataset as long word sequences.

    `Penn Tree Bank <https://www.cis.upenn.edu/~treebank/>`_ is originally a
    corpus of English sentences with linguistic structure annotations. This
    function uses a variant distributed at
    `https://github.com/tomsercu/lstm <https://github.com/tomsercu/lstm>`_,
    which omits the annotations and splits the dataset into three parts:
    training, validation, and test.

    This function returns the training, validation, and test sets, each of
    which is represented as a long array of word IDs. All sentences in the
    dataset are concatenated by the End-of-Sentence mark '<eos>', which is
    treated as one of the vocabulary words.

    Returns:
        tuple of numpy.ndarray: Int32 vectors of word IDs.

    .. seealso::
       Use :func:`get_ptb_words_vocabulary` to get the mapping between the
       words and word IDs.

    """
    train = _retrieve_ptb_words('train.npz', _train_url)
    valid = _retrieve_ptb_words('valid.npz', _valid_url)
    test = _retrieve_ptb_words('test.npz', _test_url)
    return train, valid, test
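A minimal usage sketch, assuming Chainer is installed and that this function is re-exported through chainer.datasets as in recent Chainer releases (the first call downloads and caches the data):

    from chainer.datasets import get_ptb_words

    train, valid, test = get_ptb_words()
    print(train.dtype)    # int32
    print(train.ndim)     # 1: each split is one long vector of word IDs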
def get_ptb_words_vocabulary():
    """Gets the Penn Tree Bank word vocabulary.

    Returns:
        dict: Dictionary that maps words to corresponding word IDs. The IDs
        are used in the Penn Tree Bank long sequence datasets.

    .. seealso::
       See :func:`get_ptb_words` for the actual datasets.

    """
    return _retrieve_word_vocabulary()
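To map word IDs back to words, the vocabulary can be inverted; a short sketch (the variable names are illustrative):

    from chainer.datasets import get_ptb_words, get_ptb_words_vocabulary

    vocab = get_ptb_words_vocabulary()              # word -> ID
    id_to_word = {i: w for w, i in vocab.items()}   # ID -> word

    train, _, _ = get_ptb_words()
    print(' '.join(id_to_word[int(i)] for i in train[:20]))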
_train_url = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.train.txt'  # NOQA
_valid_url = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.valid.txt'  # NOQA
_test_url = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.test.txt'  # NOQA


def _retrieve_ptb_words(name, url):
    def creator(path):
        # Convert the raw text into an int32 array of word IDs and cache it
        # as a compressed .npz file.
        vocab = _retrieve_word_vocabulary()
        words = _load_words(url)
        x = numpy.empty(len(words), dtype=numpy.int32)
        for i, word in enumerate(words):
            x[i] = vocab[word]
        numpy.savez_compressed(path, x=x)
        return {'x': x}

    root = download.get_dataset_directory('pfnet/chainer/ptb')
    path = os.path.join(root, name)
    loaded = download.cache_or_load_file(path, creator, numpy.load)
    return loaded['x']


def _retrieve_word_vocabulary():
    def creator(path):
        # Assign IDs to words in order of first appearance in the training
        # set and cache the vocabulary as one word per line.
        words = _load_words(_train_url)
        vocab = {}
        index = 0
        with open(path, 'w') as f:
            for word in words:
                if word not in vocab:
                    vocab[word] = index
                    index += 1
                    f.write(word + '\n')
        return vocab

    def loader(path):
        # Reload the cached vocabulary; the line number is the word ID.
        vocab = {}
        with open(path) as f:
            for i, word in enumerate(f):
                vocab[word.strip()] = i
        return vocab

    root = download.get_dataset_directory('pfnet/chainer/ptb')
    path = os.path.join(root, 'vocab.txt')
    return download.cache_or_load_file(path, creator, loader)


def _load_words(url):
    # Download (or reuse a cached copy of) the raw text and split it into
    # words, appending '<eos>' at the end of each line.
    path = download.cached_download(url)
    words = []
    with open(path) as words_file:
        for line in words_file:
            if line:
                words += line.strip().split()
                words.append('<eos>')
    return words
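The private helpers above all follow the chainer.dataset.download caching pattern: a creator builds a file under the dataset directory once, and a loader reuses it on later calls. A minimal sketch of the same pattern with a hypothetical file name and payload (everything below is illustrative and not part of this module):

    import os

    from chainer.dataset import download


    def _retrieve_example():
        def creator(path):
            # Runs only when the cached file does not exist yet.
            data = ['alpha', 'beta', 'gamma']
            with open(path, 'w') as f:
                f.write('\n'.join(data))
            return data

        def loader(path):
            # Runs on later calls, reading back the cached file.
            with open(path) as f:
                return f.read().splitlines()

        root = download.get_dataset_directory('pfnet/chainer/example')
        path = os.path.join(root, 'example.txt')
        return download.cache_or_load_file(path, creator, loader)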