shard_index

paddle. shard_index ( input, index_num, nshards, shard_id, ignore_value=- 1 ) [源代码]

根据当前 shard 重新设置输入参数input的值。输入input中的值需要为非负整型;参数index_num为用户设置的大于input最大值的整型值。因此,input中的值属于区间[0, index_num),且每个值可以被看作到区间起始的偏移量。区间可以被进一步划分为多个切片。具体地讲,我们首先根据下面的公式计算每个切片的大小:shard_size,表示每个切片可以表示的整数的数量。因此,对于第i个切片,其表示的区间为[i*shard_size, (i+1)*shard_size)。

shard_size = (index_num + nshards - 1) // nshards

对于输入input中的每个值v,我们根据下面的公式设置它新的值:

v = v - shard_id * shard_size if shard_id * shard_size <= v < (shard_id+1) * shard_size else ignore_value

参数

  • input (Tensor) - 输入 tensor,最后一维的维度值为 1,数据类型为 int64 或 int32。

  • index_num (int) - 用户设置的大于 input 最大值的整型值。

  • nshards (int) - 分片数量。

  • shard_id (int) - 当前分片 ID。

  • ignore_value (int) - 超出分片范围的默认值。

返回

Tensor

代码示例

import paddle
label = paddle.to_tensor([[16], [1]], "int64")
shard_label = paddle.shard_index(input=label,
                                 index_num=20,
                                 nshards=2,
                                 shard_id=0)
print(shard_label)
# [[-1], [1]]