WMT14¶
该类是对 WMT14 测试数据集实现。 由于原始WMT14数据集太大,我们在这里提供了一组小数据集。该类将从 http://paddlemodels.bj.bcebos.com/wmt/wmt14.tgz 下载数据集。
参数¶
返回值¶
Dataset
,WMT14数据集实例。
src_ids (np.array) - 源语言当前的token id序列。
trg_ids (np.array) - 目标语言当前的token id序列。
trg_ids_next (np.array) - 目标语言下一段的token id序列。
代码示例¶
import paddle
from paddle.text.datasets import WMT14
class SimpleNet(paddle.nn.Layer):
def __init__(self):
super(SimpleNet, self).__init__()
def forward(self, src_ids, trg_ids, trg_ids_next):
return paddle.sum(src_ids), paddle.sum(trg_ids), paddle.sum(trg_ids_next)
wmt14 = WMT14(mode='train', dict_size=50)
for i in range(10):
src_ids, trg_ids, trg_ids_next = wmt14[i]
src_ids = paddle.to_tensor(src_ids)
trg_ids = paddle.to_tensor(trg_ids)
trg_ids_next = paddle.to_tensor(trg_ids_next)
model = SimpleNet()
src_ids, trg_ids, trg_ids_next = model(src_ids, trg_ids, trg_ids_next)
print(src_ids.numpy(), trg_ids.numpy(), trg_ids_next.numpy())