Imdb

class paddle.text. Imdb ( data_file=None, mode='train', cutoff=150, download=True ) [源代码]

该类是对 IMDB 测试数据集的实现。

参数

  • data_file (str) - 保存压缩数据的路径,如果参数 :attr:`download`设置为 True,可设置为 None。默认为 None。

  • mode (str) - 'train' 或'test' 模式。默认为'train'。

  • cutoff (int) - 构建词典的截止大小。默认为 Default 150。

  • download (bool) - 如果 :attr:`data_file`未设置,是否自动下载数据集。默认为 True。

返回

Dataset, IMDB 数据集实例。

代码示例

>>> import paddle
>>> from paddle.text.datasets import Imdb

>>> class SimpleNet(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...
...     def forward(self, doc, label):
...         return paddle.sum(doc), label


>>> imdb = Imdb(mode='train')

>>> for i in range(10):
...     doc, label = imdb[i]
...     doc = paddle.to_tensor(doc)
...     label = paddle.to_tensor(label)
...
...     model = SimpleNet()
...     image, label = model(doc, label)
...     print(doc.shape, label.shape)
[121] [1]
[115] [1]
[386] [1]
[471] [1]
[585] [1]
[206] [1]
[221] [1]
[324] [1]
[166] [1]
[598] [1]