margin_cross_entropy

paddle.nn.functional.margin_cross_entropy(logits, label, margin1=1.0, margin2=0.5, margin3=0.0, scale=64.0, group=None, return_softmax=False, reduction='mean') [source]
\[L=-\frac{1}{N}\sum^N_{i=1}\log\frac{e^{s(\cos(m_{1}\theta_{y_i}+m_{2})-m_{3})}}{e^{s(\cos(m_{1}\theta_{y_i}+m_{2})-m_{3})}+\sum^n_{j=1,j\neq y_i} e^{s\cos\theta_{j}}}\]

where \(\theta_{y_i}\) is the angle between the feature \(x\) and the weight representation of its ground-truth class \(y_i\) (and \(\theta_j\) the angle to class \(j\)). The details of the ArcFace loss can be found at https://arxiv.org/abs/1801.07698.
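
To make the formula concrete, the sketch below (illustrative values only; it needs a GPU because the API does not run on CPU) applies \(cos(m_1\theta+m_2)-m_3\) to the target logit by hand, scales by \(s\), and feeds the result to a plain softmax cross entropy; the manually computed loss should agree with margin_cross_entropy up to numerical precision. The variable names here are illustrative and not part of the API.

import paddle

m1, m2, m3, s = 1.0, 0.5, 0.0, 64.0
# cosine logits in [-1, 1]; in practice these come from normalized X matmul normalized W
logits = paddle.to_tensor([[0.3, -0.2, 0.5], [0.1, 0.7, -0.4]], dtype='float64')
label = paddle.to_tensor([2, 1], dtype='int64')

# apply the margin cos(m1 * theta + m2) - m3 to the target class only, then scale by s
theta = paddle.acos(paddle.clip(logits, -1.0, 1.0))
one_hot = paddle.nn.functional.one_hot(label, num_classes=logits.shape[1]).astype('float64')
adjusted = s * (one_hot * (paddle.cos(m1 * theta + m2) - m3) + (1.0 - one_hot) * logits)
manual_loss = paddle.nn.functional.cross_entropy(adjusted, label, reduction='none')

api_loss = paddle.nn.functional.margin_cross_entropy(
    logits, label, margin1=m1, margin2=m2, margin3=m3, scale=s, reduction=None)

print(manual_loss)  # expected to agree with api_loss elementwise (shapes may differ by a trailing 1)
print(api_loss)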

Hint

This API supports single GPU and multi GPU, and does not support CPU. For data parallel mode, set group=False. For model parallel mode, set group=None or to the group instance returned by paddle.distributed.new_group. In model parallel mode, logits.shape[-1] can be different at each rank.
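
As a minimal sketch of the data parallel setting from the hint (group=False, so each rank keeps the full class dimension and no cross-rank communication happens), the snippet below runs on a single GPU with made-up tensors; it is illustrative, not part of the reference examples further down.

import paddle

# full [N, num_classes] logits on this rank; random values stand in for normalized X matmul W
logits = paddle.rand([2, 4], dtype='float64')
label = paddle.randint(low=0, high=4, shape=[2], dtype='int64')

# group=False: data parallel, no communication across ranks
loss = paddle.nn.functional.margin_cross_entropy(logits, label, group=False)
print(loss)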

Parameters
  • logits (Tensor) – shape [N, local_num_classes], the output of the normalized X multiplied by the normalized W. logits is shard_logits when using model parallelism.

  • label (Tensor) – shape [N] or shape [N, 1], the ground truth label.

  • margin1 (float, optional) – m1 of margin loss, default value is 1.0.

  • margin2 (float, optional) – m2 of margin loss, default value is 0.5.

  • margin3 (float, optional) – m3 of margin loss, default value is 0.0.

  • scale (float, optional) – s of margin loss, default value is 64.0.

  • group (Group, optional) – The group instance returned by paddle.distributed.new_group, None for the global default group, or False for data parallel (no communication across ranks). Default is None.

  • return_softmax (bool, optional) – Whether return softmax probability. Default value is False.

  • reduction (str, optional) – The candidates are 'none' | 'mean' | 'sum'. If reduction is 'mean', return the average of the loss; if reduction is 'sum', return the sum of the loss; if reduction is 'none', no reduction will be applied. Default value is 'mean'.

Returns

Tensor|tuple[Tensor, Tensor]. Return the cross entropy loss if return_softmax is False, otherwise return the tuple (loss, softmax). softmax is shard_softmax when using model parallelism; otherwise softmax has the same shape as the input logits. If reduction is 'none' (or None), the shape of loss is [N, 1]; otherwise the shape is [].
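
A quick shape check for the return values described above (requires a GPU; the tensors are random placeholders and the commented shapes follow the description rather than a guaranteed print format):

import paddle

logits = paddle.rand([2, 4], dtype='float64')
label = paddle.randint(low=0, high=4, shape=[2], dtype='int64')

loss_mean = paddle.nn.functional.margin_cross_entropy(logits, label)  # reduction='mean'
loss_none, softmax = paddle.nn.functional.margin_cross_entropy(
    logits, label, return_softmax=True, reduction=None)

print(loss_mean.shape)  # [] (scalar), per the description above
print(loss_none.shape)  # [2, 1]
print(softmax.shape)    # [2, 4], same shape as logits without model parallelism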

Examples:

# required: gpu
# Single GPU
import paddle
m1 = 1.0
m2 = 0.5
m3 = 0.0
s = 64.0
batch_size = 2
feature_length = 4
num_classes = 4

label = paddle.randint(low=0, high=num_classes, shape=[batch_size], dtype='int64')

X = paddle.randn(
    shape=[batch_size, feature_length],
    dtype='float64')
# L2-normalize the features so the logits become cosine similarities
X_l2 = paddle.sqrt(paddle.sum(paddle.square(X), axis=1, keepdim=True))
X = paddle.divide(X, X_l2)

W = paddle.randn(
    shape=[feature_length, num_classes],
    dtype='float64')
# L2-normalize the class weight vectors
W_l2 = paddle.sqrt(paddle.sum(paddle.square(W), axis=0, keepdim=True))
W = paddle.divide(W, W_l2)

logits = paddle.matmul(X, W)
loss, softmax = paddle.nn.functional.margin_cross_entropy(
    logits, label, margin1=m1, margin2=m2, margin3=m3, scale=s, return_softmax=True, reduction=None)

print(logits)
print(label)
print(loss)
print(softmax)

#Tensor(shape=[2, 4], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
#       [[ 0.85204151, -0.55557678,  0.04994566,  0.71986042],
#        [-0.20198586, -0.35270476, -0.55182702,  0.09749021]])
#Tensor(shape=[2], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
#       [2, 3])
#Tensor(shape=[2, 1], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
#       [[82.37059586],
#        [12.13448420]])
#Tensor(shape=[2, 4], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
#       [[0.99978819, 0.00000000, 0.00000000, 0.00021181],
#        [0.99992995, 0.00006468, 0.00000000, 0.00000537]])

# required: distributed
# Multi GPU, test_margin_cross_entropy.py
import paddle
import paddle.distributed as dist
strategy = dist.fleet.DistributedStrategy()
dist.fleet.init(is_collective=True, strategy=strategy)
rank_id = dist.get_rank()
m1 = 1.0
m2 = 0.5
m3 = 0.0
s = 64.0
batch_size = 2
feature_length = 4
num_class_per_card = [4, 8]
num_classes = paddle.sum(paddle.to_tensor(num_class_per_card))

label = paddle.randint(low=0, high=num_classes.item(), shape=[batch_size], dtype='int64')
# gather the labels from all ranks to form the global label tensor
label_list = []
dist.all_gather(label_list, label)
label = paddle.concat(label_list, axis=0)

X = paddle.randn(
    shape=[batch_size, feature_length],
    dtype='float64')
# gather the features from all ranks to form the global feature tensor
X_list = []
dist.all_gather(X_list, X)
X = paddle.concat(X_list, axis=0)
X_l2 = paddle.sqrt(paddle.sum(paddle.square(X), axis=1, keepdim=True))
X = paddle.divide(X, X_l2)

W = paddle.randn(
    shape=[feature_length, num_class_per_card[rank_id]],
    dtype='float64')
W_l2 = paddle.sqrt(paddle.sum(paddle.square(W), axis=0, keepdim=True))
W = paddle.divide(W, W_l2)

logits = paddle.matmul(X, W)
loss, softmax = paddle.nn.functional.margin_cross_entropy(
    logits, label, margin1=m1, margin2=m2, margin3=m3, scale=s, return_softmax=True, reduction=None)

print(logits)
print(label)
print(loss)
print(softmax)

# python -m paddle.distributed.launch --gpus=0,1 test_margin_cross_entropy.py
## for rank0 input
#Tensor(shape=[4, 4], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
#       [[ 0.32888934,  0.02408748, -0.02763289,  0.18173063],
#        [-0.52893978, -0.10623845, -0.21596515, -0.06432517],
#        [-0.00536345, -0.03924667,  0.66735314, -0.28640926],
#        [-0.09907366, -0.48534973, -0.10365338, -0.39472322]])
#Tensor(shape=[4], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
#       [11, 1 , 10, 11])

## for rank1 input
#Tensor(shape=[4, 8], dtype=float64, place=CUDAPlace(1), stop_gradient=True,
#       [[ 0.68654754,  0.28137170,  0.69694954, -0.60923933, -0.57077653,  0.54576703, -0.38709028,  0.56028204],
#        [-0.80360371, -0.03042448, -0.45107338,  0.49559349,  0.69998950, -0.45411693,  0.61927630, -0.82808600],
#        [ 0.11457570, -0.34785879, -0.68819499, -0.26189226, -0.48241491, -0.67685711,  0.06510185,  0.49660849],
#        [ 0.31604851,  0.52087884,  0.53124749, -0.86176582, -0.43426329,  0.34786144, -0.10850784,  0.51566383]])
#Tensor(shape=[4], dtype=int64, place=CUDAPlace(1), stop_gradient=True,
#       [11, 1 , 10, 11])

## for rank0 output
#Tensor(shape=[4, 1], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
#       [[38.96608230],
#        [81.28152394],
#        [69.67229865],
#        [31.74197251]])
#Tensor(shape=[4, 4], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
#       [[0.00000000, 0.00000000, 0.00000000, 0.00000000],
#        [0.00000000, 0.00000000, 0.00000000, 0.00000000],
#        [0.00000000, 0.00000000, 0.99998205, 0.00000000],
#        [0.00000000, 0.00000000, 0.00000000, 0.00000000]])
## for rank1 output
#Tensor(shape=[4, 1], dtype=float64, place=CUDAPlace(1), stop_gradient=True,
#       [[38.96608230],
#        [81.28152394],
#        [69.67229865],
#        [31.74197251]])
#Tensor(shape=[4, 8], dtype=float64, place=CUDAPlace(1), stop_gradient=True,
#       [[0.33943993, 0.00000000, 0.66051859, 0.00000000, 0.00000000, 0.00004148, 0.00000000, 0.00000000],
#        [0.00000000, 0.00000000, 0.00000000, 0.00000207, 0.99432097, 0.00000000, 0.00567696, 0.00000000],
#        [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00001795],
#        [0.00000069, 0.33993085, 0.66006319, 0.00000000, 0.00000000, 0.00000528, 0.00000000, 0.00000000]])