dsplit

paddle. dsplit ( x, num_or_indices, name=None ) [源代码]

dsplit 全称 Depth Split ,即深度分割,将输入 Tensor 沿着深度轴分割成多个子 Tensor,等价于将 tensor_split API 的参数 axis 固定为 2。

注解

请确保使用 paddle.dsplit 进行变换的 Tensor 维度数量不少于 3。

如下图,Tenser x 的 shape 为[4, 4, 4],经过 paddle.dsplit(x, num_or_indices=2) 变换后,得到 out0out1 两个 shape 均为[4, 4, 2]的子 Tensor :

dsplit 图例

参数

  • x (Tensor) - 输入变量,数据类型为 bool、bfloat16、float16、float32、float64、uint8、int8、int32、int64 的多维 Tensor,其维度必须大于 2。

  • num_or_indices (int|list|tuple) - 如果 num_or_indices 是一个整数 n ,则 x 拆分为 n 部分。如果 num_or_indices 是整数索引的列表或元组,则在每个索引处分割 x

  • name (str,可选) - 具体用法请参见 Name,一般无需设置,默认值为 None。

返回

list[Tensor],分割后的 Tensor 列表。

代码示例

>>> import paddle

>>> # x is a Tensor of shape [7, 6, 8]
>>> x = paddle.rand([7, 6, 8])
>>> out0, out1 = paddle.dsplit(x, num_or_indices=2)
>>> print(out0.shape)
[7, 6, 4]
>>> print(out1.shape)
[7, 6, 4]

>>> out0, out1, out2 = paddle.dsplit(x, num_or_indices=[1, 4])
>>> print(out0.shape)
[7, 6, 1]
>>> print(out1.shape)
[7, 6, 3]
>>> print(out2.shape)
[7, 6, 4]