softmax_with_cross_entropy
- paddle.nn.functional.softmax_with_cross_entropy(logits, label, soft_label=False, ignore_index=-100, numeric_stable_mode=True, return_softmax=False, axis=-1) [source]
Warning
API “paddle.nn.functional.loss.softmax_with_cross_entropy” is deprecated since 2.0.0 and will be removed in future versions. Please use “paddle.nn.functional.cross_entropy” instead. Reason: the behavior of “paddle.nn.functional.softmax_with_cross_entropy” differs from that of “paddle.nn.functional.cross_entropy”.
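As a rough migration sketch (this is not a drop-in equivalence; the two APIs behave differently, notably in the default reduction and possibly in output shapes), something like the following can serve as a starting point:

import paddle
import paddle.nn.functional as F

logits = paddle.to_tensor([[0.4, 0.6, 0.9]], dtype="float32")
label = paddle.to_tensor([[1]], dtype="int64")

# Deprecated API: returns per-sample losses, no reduction is applied.
old_loss = F.softmax_with_cross_entropy(logits=logits, label=label)

# Suggested replacement: cross_entropy also applies softmax internally, but it
# reduces to the mean by default; pass reduction='none' to keep per-sample
# losses. Verify shapes and values against your own use case.
new_loss = F.cross_entropy(logits, label, reduction='none')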
This operator implements the cross entropy loss function with softmax. This function combines the calculation of the softmax operation and the cross entropy loss function to provide a more numerically stable gradient.
Because this operator performs a softmax on logits internally, it expects unscaled logits. It should not be used with the output of the softmax operator, since that would produce incorrect results, as illustrated in the sketch below.
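A small illustrative sketch of this warning (not part of the official example): feeding pre-softmaxed values changes the result.

import paddle
import paddle.nn.functional as F

logits = paddle.to_tensor([[2.0, 0.5, -1.0]], dtype="float32")
label = paddle.to_tensor([[0]], dtype="int64")

# Correct: pass raw (unscaled) logits; softmax is applied internally.
loss_ok = F.softmax_with_cross_entropy(logits=logits, label=label)

# Incorrect: pre-applying softmax normalizes twice and yields a different,
# wrong loss value.
loss_wrong = F.softmax_with_cross_entropy(logits=F.softmax(logits), label=label)

print(loss_ok, loss_wrong)  # the two values differ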
When the attribute soft_label is set to False, this operator expects mutually exclusive hard labels: each sample in a batch belongs to exactly one class with a probability of 1.0, so each sample has a single label. The equations are as follows:

Hard label (one-hot label, so every sample has exactly one class)

\[loss_j = -\text{logits}_{label_j} + \log\left(\sum_{i=0}^{K}\exp(\text{logits}_i)\right), \quad j = 1, ..., K\]

Soft label (each sample can have a distribution over all classes)

\[loss_j = -\sum_{i=0}^{K}\text{label}_i\left(\text{logits}_i - \log\left(\sum_{i=0}^{K}\exp(\text{logits}_i)\right)\right), \quad j = 1, ..., K\]

If numeric_stable_mode is True, softmax is calculated first by:

\[\begin{split}max_j &= \max_{i=0}^{K}{\text{logits}_i} \\ log\_max\_sum_j &= \log\sum_{i=0}^{K}\exp(\text{logits}_i - max_j) \\ softmax_j &= \exp(\text{logits}_j - max_j - log\_max\_sum_j)\end{split}\]

and then the cross entropy loss is calculated from the softmax result and the label.
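As a quick sanity check (plain NumPy, not part of the Paddle API), the hard-label formula and its numerically stable rewrite can be evaluated directly; both reproduce the value in the example at the bottom of this page:

import numpy as np

logits = np.array([0.4, 0.6, 0.9], dtype=np.float64)
label = 1  # index of the true class

# Naive form: loss = -logits[label] + log(sum_i exp(logits_i))
naive = -logits[label] + np.log(np.sum(np.exp(logits)))

# Numerically stable form: subtract the max before exponentiating.
m = np.max(logits)
log_max_sum = np.log(np.sum(np.exp(logits - m)))
softmax = np.exp(logits - m - log_max_sum)
stable = -np.log(softmax[label])

print(naive, stable)  # both approximately 1.15328646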
- Parameters

  - logits (Tensor) – A multi-dimensional Tensor whose data type is float32 or float64; the input tensor of unscaled log probabilities.
  - label (Tensor) – The ground truth Tensor, with the same data type as logits. If soft_label is set to True, label is a Tensor with the same shape as logits. If soft_label is set to False, label is a Tensor with the same shape as logits except in dimension axis, where its size is 1. (Both conventions are shown in the sketch after this list.)
  - soft_label (bool, optional) – A flag indicating whether to interpret the given labels as soft labels. Default: False.
  - ignore_index (int, optional) – Specifies a target value that is ignored and does not contribute to the input gradient. Only valid if soft_label is set to False. Default: kIgnoreIndex (-100).
  - numeric_stable_mode (bool, optional) – A flag indicating whether to use a more numerically stable algorithm. Only valid when soft_label is False and a GPU is used. When soft_label is True or a CPU is used, the algorithm is always numerically stable. Note that the stable algorithm may be slower. Default: True.
  - return_softmax (bool, optional) – A flag indicating whether to return the softmax result along with the cross entropy loss. Default: False.
  - axis (int, optional) – The index of the dimension along which to perform the softmax calculation. It should be in the range \([-1, rank - 1]\), where \(rank\) is the rank of the input logits. Default: -1.
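The following is a minimal sketch of the two label conventions described above (hard integer labels vs. a soft per-class distribution); the output shapes noted in the comment are what the Returns section below implies:

import paddle
import paddle.nn.functional as F

logits = paddle.to_tensor([[0.4, 0.6, 0.9],
                           [1.2, 0.1, 0.3]], dtype="float32")

# Hard labels: class indices, same shape as logits except the axis dimension is 1.
hard_label = paddle.to_tensor([[1], [0]], dtype="int64")
hard_loss = F.softmax_with_cross_entropy(logits=logits, label=hard_label)

# Soft labels: a per-class distribution with the same shape and dtype as logits.
soft_label = paddle.to_tensor([[0.1, 0.7, 0.2],
                               [0.8, 0.1, 0.1]], dtype="float32")
soft_loss = F.softmax_with_cross_entropy(logits=logits, label=soft_label,
                                         soft_label=True)

print(hard_loss.shape, soft_loss.shape)  # expected: [2, 1] and [2, 1]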
- Returns

  If return_softmax is False, return the cross entropy loss as a Tensor. Its dtype is the same as the input logits, and its shape is consistent with logits except in dimension axis, where its size is 1.

  If return_softmax is True, return a tuple of two Tensors: the cross entropy loss and the softmax result (see the sketch below). The dtype of the cross entropy loss is the same as the input logits, and its shape is consistent with logits except in dimension axis, where its size is 1. The dtype and shape of the softmax result are the same as the input logits.
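A minimal sketch of the return_softmax=True form described above; the shapes in the comments are what the description implies:

import paddle
import paddle.nn.functional as F

logits = paddle.to_tensor([[0.4, 0.6, 0.9]], dtype="float32")
label = paddle.to_tensor([[1]], dtype="int64")

# With return_softmax=True, the call returns a (loss, softmax) tuple.
loss, softmax = F.softmax_with_cross_entropy(logits=logits, label=label,
                                             return_softmax=True)
print(loss.shape)     # [1, 1]: same as logits except the axis dimension is 1
print(softmax.shape)  # [1, 3]: same shape and dtype as logits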
Examples
>>> import paddle
>>> logits = paddle.to_tensor([0.4, 0.6, 0.9], dtype="float32")
>>> label = paddle.to_tensor([1], dtype="int64")
>>> out = paddle.nn.functional.softmax_with_cross_entropy(logits=logits, label=label)
>>> print(out)
Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
       [1.15328646])