einsum

paddle. einsum ( equation, *operands ) [源代码]

该函数用于对一组输入张量进行 Einstein 求和,该函数目前仅适用于动态图

Einstein 求和是一种采用 Einstein 标记法描述的张量求和,输入单个或多个张量,输出单个张量。

如下的张量操作或运算均可视为 Einstein 求和的特例

  • 单操作数
    • 迹:trace

    • 对角元:diagonal

    • 转置:transpose

    • 求和:sum

  • 双操作数
    • 内积:dot

    • 外积:outer

    • 广播乘积:mul,*

    • 矩阵乘:matmul

    • 批量矩阵乘:bmm

  • 多操作数
    • 广播乘积:mul,*

    • 多矩阵乘:A.matmul(B).matmul(C)

关于求和标记的约定

  • 维度分量下标:张量的维度分量下标使用英文字母表示,不区分大小写,如'ijk'表示张量维度分量为 i,j,k

  • 下标对应输入操作数:维度下标以`,`分段,按顺序 1-1 对应输入操作数

  • 广播维度:省略号`...`表示维度的广播分量,例如,'i...j'表示首末分量除外的维度需进行广播对齐

  • 自由标和哑标:输入标记中仅出现一次的下标为自由标,重复出现的下标为哑标,哑标对应的维度分量将被规约消去

  • 输出:输出张量的维度分量既可由输入标记自动推导,也可以用输出标记定制化
    • 自动推导输出
      • 广播维度分量位于维度向量高维位置,自由标维度分量按字母顺序排序,位于维度向量低纬位置,哑标维度分量不输出

    • 定制化输出
      • 维度标记中`->`右侧为输出标记

      • 若输出包含广播维度,则输出标记需包含`...`

      • 输出标记为空时,对输出进行全量求和,返回该标量

      • 输出不能包含输入标记中未出现的下标

      • 输出下标不可以重复出现

      • 哑标出现在输出标记中则自动提升为自由标

      • 输出标记中未出现的自由标被降为哑标

  • 例子
    • '...ij, ...jk',该标记中 i,k 为自由标,j 为哑标,输出维度'...ik'

    • 'ij -> i',i 为自由标,j 为哑标

    • '...ij, ...jk -> ...ijk',i,j,k 均为自由标

    • '...ij, ...jk -> ij',若输入张量中的广播维度不为空,则该标记为无效标记

求和规则

Einsum 求和过程理论上等价于如下四步,但实现中实际执行的步骤会有差异。

  • 第一步,维度对齐:将所有标记按字母序排序,按照标记顺序将输入张量逐一转置、补齐维度,使得处理后的所有张量其维度标记保持一致

  • 第二步,广播乘积:以维度下标为索引进行广播点乘

  • 第三步,维度规约:将哑标对应的维度分量求和消除

  • 第四步,转置输出:若存在输出标记,则按标记进行转置,否则按广播维度+字母序自由标的顺序转置,返回转之后的张量作为输出

关于 trace 和 diagonal 的标记约定(待实现功能)

  • 在单个输入张量的标记中重复出现的下标称为对角标,对角标对应的坐标轴需进行对角化操作,如'i...i'表示需对首尾坐标轴进行对角化

  • 若无输出标记或输出标记中不包含对角标,则对角标对应维度规约为标量,相应维度取消,等价于 trace 操作

  • 若输出标记中包含对角标,则保留对角标维度,等价于 diagonal 操作

参数

equation (str):求和标记

operands (Tensor, [Tensor, ...]):输入张量

返回

Tensor:输出张量

代码示例

import paddle
paddle.seed(102)
x = paddle.rand([4])
y = paddle.rand([5])

# sum
print(paddle.einsum('i->', x))
# Tensor(shape=[], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
#   1.95791852)

# dot
print(paddle.einsum('i,i->', x, x))
# Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
#   [1.45936954])

# outer
print(paddle.einsum("i,j->ij", x, y))
# Tensor(shape=[4, 5], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
#   [[0.00079869, 0.00120950, 0.00136844, 0.00187187, 0.00192194],
#    [0.23455200, 0.35519385, 0.40186870, 0.54970956, 0.56441545],
#    [0.11773264, 0.17828843, 0.20171674, 0.27592498, 0.28330654],
#    [0.32897076, 0.49817693, 0.56364071, 0.77099484, 0.79162055]])

A = paddle.rand([2, 3, 2])
B = paddle.rand([2, 2, 3])

# transpose
print(paddle.einsum('ijk->kji', A))
#  Tensor(shape=[2, 3, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
#   [[[0.95649719, 0.49684682],
#     [0.80071914, 0.46258664],
#     [0.49814570, 0.33383518]],
#
#    [[0.07637714, 0.29374704],
#     [0.51470858, 0.51907635],
#     [0.99066722, 0.55802226]]])

# batch matrix multiplication
print(paddle.einsum('ijk, ikl->ijl', A,B))
# Tensor(shape=[2, 3, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
#   [[[0.32172769, 0.50617385, 0.41394392],
#     [0.51736701, 0.49921003, 0.38730967],
#     [0.69078457, 0.42282537, 0.30161136]],
#
#    [[0.32043904, 0.18164253, 0.27810261],
#     [0.50226176, 0.24512935, 0.39881429],
#     [0.51476848, 0.23367381, 0.39229113]]])

# Ellipsis transpose
print(paddle.einsum('...jk->...kj', A))
# Tensor(shape=[2, 2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
#   [[[0.95649719, 0.80071914, 0.49814570],
#     [0.07637714, 0.51470858, 0.99066722]],
#
#    [[0.49684682, 0.46258664, 0.33383518],
#     [0.29374704, 0.51907635, 0.55802226]]])

# Ellipsis batch matrix multiplication
print(paddle.einsum('...jk, ...kl->...jl', A,B))
# Tensor(shape=[2, 3, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
#   [[[0.32172769, 0.50617385, 0.41394392],
#     [0.51736701, 0.49921003, 0.38730967],
#     [0.69078457, 0.42282537, 0.30161136]],
#
#    [[0.32043904, 0.18164253, 0.27810261],
#     [0.50226176, 0.24512935, 0.39881429],
#     [0.51476848, 0.23367381, 0.39229113]]])