unflatten¶
将输入 Tensor 沿指定轴 axis 上的维度展成 shape 形状。与 flatten 是反函数。
参数¶
x (Tensor) - 输入多维 Tensor,可选的数据类型为 'float16'、'float32'、'float64'、'int16'、'int32'、'int64'、'bool'、'uint16'。
axis (int) - 要展开维度的轴,作为
x.shape
的索引。shape (list|tuple|Tensor) - 在指定轴上将该维度展成
shape
, 其中shape
最多包含一个 -1,如果输入shape
不包含 -1 ,则shape
内元素累乘的结果应该等于x.shape[axis]
。数据类型为int
。如果shape
的类型是list
或tuple
,它的元素可以是整数或者形状为[]的Tensor
(0-DTensor
)。如果shape
的类型是Tensor
,则是 1-D 的Tensor
。name (str,可选) - 具体用法请参见 Name,一般无需设置,默认值为 None。
返回¶
Tensor,沿指定轴将维度展开的后的 x
。
代码示例¶
import paddle
x = paddle.randn(shape=[4, 6, 8])
shape = [2, 3]
axis = 1
res = paddle.unflatten(x, axis, shape)
print(res.shape)
# [4, 2, 3, 8]
x = paddle.randn(shape=[4, 6, 8])
shape = (-1, 2)
axis = -1
res = paddle.unflatten(x, axis, shape)
print(res.shape)
# [4, 6, 4, 2]
x = paddle.randn(shape=[4, 6, 8])
shape = paddle.to_tensor([2, 2])
axis = 0
res = paddle.unflatten(x, axis, shape)
print(res.shape)
# [2, 2, 6, 8]