WMT14
- class paddle.text.WMT14(data_file=None, mode='train', dict_size=-1, download=True)
Implementation of the WMT14 test dataset. Because the original WMT14 dataset is too large, a small subset of the data is provided instead. This module downloads the dataset from http://paddlemodels.bj.bcebos.com/wmt/wmt14.tgz.
- Parameters
data_file (str) – Path to the data tar file. Can be None if download is True. Default: None.
mode (str) – 'train', 'test' or 'gen'. Default: 'train'.
dict_size (int) – Word dictionary size. Default: -1.
download (bool) – Whether to download the dataset automatically if data_file is not set. Default: True.
- Returns
Instance of the WMT14 dataset. Each sample consists of three arrays:
src_ids (np.array) - The sequence of token ids of the source language.
trg_ids (np.array) - The sequence of token ids of the target language.
trg_ids_next (np.array) - The sequence of next token ids of the target language, i.e. the prediction target paired with trg_ids.
- Return type
Dataset
Examples
>>> import paddle
>>> from paddle.text.datasets import WMT14
>>> class SimpleNet(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...
...     def forward(self, src_ids, trg_ids, trg_ids_next):
...         return paddle.sum(src_ids), paddle.sum(trg_ids), paddle.sum(trg_ids_next)
>>> wmt14 = WMT14(mode='train', dict_size=50)
>>> for i in range(10):
...     src_ids, trg_ids, trg_ids_next = wmt14[i]
...     src_ids = paddle.to_tensor(src_ids)
...     trg_ids = paddle.to_tensor(trg_ids)
...     trg_ids_next = paddle.to_tensor(trg_ids_next)
...
...     model = SimpleNet()
...     src_ids, trg_ids, trg_ids_next = model(src_ids, trg_ids, trg_ids_next)
...     print(src_ids.item(), trg_ids.item(), trg_ids_next.item())
91 38 39
123 81 82
556 229 230
182 26 27
447 242 243
116 110 111
403 288 289
258 221 222
136 34 35
281 136 137
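The relationship between the three arrays in a sample can be illustrated without downloading the dataset. The sketch below assumes the common teacher-forcing layout, in which trg_ids begins with a start-of-sequence id and trg_ids_next is the same sentence shifted by one position and closed with an end-of-sequence id. The toy vocabularies, the ids 0, 1 and 2 for the special tokens, and the helper encode_pair are hypothetical and are not part of the paddle.text API.

```python
# Toy illustration only: the vocabularies, special-token ids, and encode_pair
# are assumptions made for this sketch, not the WMT14 implementation.
START, END, UNK = "<s>", "<e>", "<unk>"
toy_src_dict = {START: 0, END: 1, UNK: 2, "hello": 3, "world": 4}
toy_trg_dict = {START: 0, END: 1, UNK: 2, "bonjour": 3, "monde": 4}


def encode_pair(src_words, trg_words, src_dict, trg_dict):
    """Map words to ids; words missing from a dictionary fall back to <unk>."""
    src_ids = [src_dict.get(w, src_dict[UNK]) for w in [START] + src_words + [END]]
    trg = [trg_dict.get(w, trg_dict[UNK]) for w in trg_words]
    trg_ids = [trg_dict[START]] + trg      # decoder input
    trg_ids_next = trg + [trg_dict[END]]   # decoder target, shifted by one
    return src_ids, trg_ids, trg_ids_next


src_ids, trg_ids, trg_ids_next = encode_pair(
    ["hello", "world"], ["bonjour", "monde"], toy_src_dict, toy_trg_dict
)
print(src_ids)       # [0, 3, 4, 1]
print(trg_ids)       # [0, 3, 4]
print(trg_ids_next)  # [3, 4, 1]
```

Under this assumption the two target arrays have the same length and their sums differ by exactly the end id minus the start id, which is consistent with the constant difference of 1 between the second and third columns of the sample output above.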
- get_dict(reverse=False)
Get the source and target dictionary.
- Parameters
reverse (bool) – Whether to swap keys and values in the dictionaries, i.e. turn key: value into value: key. Default: False.
- Returns
Two dictionaries, the source and target dictionary.
Examples
>>> from paddle.text.datasets import WMT14
>>> wmt14 = WMT14(mode='train', dict_size=50)
>>> src_dict, trg_dict = wmt14.get_dict()
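The effect of reverse can be pictured with a plain Python dictionary. This is a minimal sketch that assumes the non-reversed dictionaries map words to integer ids; toy_dict and its contents are hypothetical stand-ins for what get_dict() returns.

```python
# Hypothetical word -> id dictionary standing in for one returned by get_dict().
toy_dict = {"<s>": 0, "<e>": 1, "<unk>": 2, "hello": 3, "world": 4}

# Swapping key and value gives an id -> word mapping, which is the shape
# that reverse=True is described as producing.
reversed_dict = {idx: word for word, idx in toy_dict.items()}

# An id -> word mapping makes it easy to turn token ids back into text.
token_ids = [0, 3, 4, 1]
words = [reversed_dict[i] for i in token_ids]
print(" ".join(words))  # <s> hello world <e>
```

A reversed dictionary of this kind is typically what is needed when inspecting samples or printing generated translations, since the dataset itself yields ids rather than words.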