log_softmax

paddle.nn.functional. log_softmax ( x, axis=- 1, dtype=None, name=None ) [源代码]

实现了 log_softmax 层。计算公式如下:

\[\begin{split}\begin{aligned} log\_softmax[i, j] &= log(softmax(x)) \\ &= log(\frac{\exp(X[i, j])}{\sum_j(\exp(X[i, j])}) \end{aligned}\end{split}\]

参数

  • x (Tensor) - 输入的 Tensor,数据类型为:float32、float64。

  • axis (int,可选) - 指定对输入 x 进行运算的轴。axis 的有效范围是[-D, D),D 是输入 x 的维度,axis 为负值时与 \(axis + D\) 等价。默认值为-1。

  • dtype (str|np.dtype|core.VarDesc.VarType,可选) - 输入 Tensor 的数据类型。如果指定了 dtype,则输入 Tensor 的数据类型会在计算前转换到 dtypedtype 可以用来避免数据溢出。如果 dtype 为 None,则输出 Tensor 的数据类型和 x 相同。默认值为 None。

  • name (str,可选) - 具体用法请参见 Name,一般无需设置,默认值为 None。

返回

Tensor,形状和 x 相同,数据类型为 dtype 或者和 x 相同。

代码示例

>>> import paddle
>>> import paddle.nn.functional as F
>>> x = [[[-2.0, 3.0, -4.0, 5.0],
...       [3.0, -4.0, 5.0, -6.0],
...       [-7.0, -8.0, 8.0, 9.0]],
...      [[1.0, -2.0, -3.0, 4.0],
...       [-5.0, 6.0, 7.0, -8.0],
...       [6.0, 7.0, 8.0, 9.0]]]
>>> x = paddle.to_tensor(x)
>>> out1 = F.log_softmax(x)
>>> out2 = F.log_softmax(x, dtype='float64')
>>> #out1's data type is float32; out2's data type is float64
>>> #out1 and out2's value is as follows:
>>> print(out1)
Tensor(shape=[2, 3, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
[[[-7.12783957 , -2.12783957 , -9.12783909 , -0.12783945 ],
  [-2.12705135 , -9.12705135 , -0.12705141 , -11.12705135],
  [-16.31326103, -17.31326103, -1.31326187 , -0.31326184 ]],
 [[-3.05181193 , -6.05181217 , -7.05181217 , -0.05181199 ],
  [-12.31326675, -1.31326652 , -0.31326646 , -15.31326675],
  [-3.44018984 , -2.44018984 , -1.44018972 , -0.44018975 ]]])
>>> print(out2)
Tensor(shape=[2, 3, 4], dtype=float64, place=Place(cpu), stop_gradient=True,
[[[-7.12783948 , -2.12783948 , -9.12783948 , -0.12783948 ],
  [-2.12705141 , -9.12705141 , -0.12705141 , -11.12705141],
  [-16.31326180, -17.31326180, -1.31326180 , -0.31326180 ]],
 [[-3.05181198 , -6.05181198 , -7.05181198 , -0.05181198 ],
  [-12.31326640, -1.31326640 , -0.31326640 , -15.31326640],
  [-3.44018970 , -2.44018970 , -1.44018970 , -0.44018970 ]]])