spectral_norm

paddle.nn.utils.spectral_norm(layer, name='weight', n_power_iterations=1, eps=1e-12, dim=None)

Applies spectral normalization to a parameter of the given layer according to the following calculation:

Step 1: Generate vector U of shape [H] and vector V of shape [W], where H is the size of the dim-th dimension of the input weight and W is the product of the sizes of the remaining dimensions.
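Step 1 can be sketched in NumPy as below. This is an illustration, not Paddle's implementation: weight_to_matrix and init_uv are hypothetical helpers, and drawing U and V from a standard normal distribution before scaling them to unit length is an assumption. The text above only fixes their shapes.

```python
import numpy as np


def weight_to_matrix(weight: np.ndarray, dim: int) -> np.ndarray:
    """Permute axis `dim` to the front, then flatten the rest into one axis: [H, W]."""
    moved = np.moveaxis(weight, dim, 0)
    return moved.reshape(moved.shape[0], -1)


def init_uv(h: int, w: int, eps: float = 1e-12, seed: int = 0):
    """Hypothetical init (assumption): U ~ N(0, 1)^H, V ~ N(0, 1)^W, scaled to unit L2 norm."""
    rng = np.random.default_rng(seed)
    u = rng.standard_normal(h)
    v = rng.standard_normal(w)
    u = u / max(np.linalg.norm(u), eps)  # eps guards against dividing by zero
    v = v / max(np.linalg.norm(v), eps)
    return u, v


# A Conv2D(3, 1, 3) weight has shape [1, 3, 3, 3], as in the Examples section.
weight = np.random.default_rng(1).standard_normal((1, 3, 3, 3))
mat = weight_to_matrix(weight, dim=0)  # H = 1, W = 3 * 3 * 3 = 27
u, v = init_uv(*mat.shape)
print(mat.shape, u.shape, v.shape)  # (1, 27) (1,) (27,)
```

With dim=1 the same weight becomes a [3, 9] matrix instead, so U and V would have shapes [3] and [9].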

Step 2: n_power_iterations should be a positive integer. Perform the following calculations with U and V for n_power_iterations rounds.

\[ \begin{aligned}\mathbf{v} &:= \frac{\mathbf{W}^{T} \mathbf{u}}{\|\mathbf{W}^{T} \mathbf{u}\|_2} \\ \mathbf{u} &:= \frac{\mathbf{W} \mathbf{v}}{\|\mathbf{W} \mathbf{v}\|_2}\end{aligned} \]
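A minimal NumPy sketch of the Step 2 updates, assuming W is already the [H, W] matrix from Step 1. power_iteration is a hypothetical helper, and dividing by max(norm, eps) is one plausible way to use eps for numerical stability, not necessarily Paddle's exact formula.

```python
import numpy as np


def power_iteration(W: np.ndarray, u: np.ndarray, n_power_iterations: int = 1, eps: float = 1e-12):
    """Apply the Step 2 updates to (u, v) for n_power_iterations rounds."""
    if n_power_iterations < 1:
        raise ValueError("n_power_iterations should be a positive integer")
    for _ in range(n_power_iterations):
        v = W.T @ u
        v = v / max(np.linalg.norm(v), eps)  # v := W^T u / ||W^T u||_2
        u = W @ v
        u = u / max(np.linalg.norm(u), eps)  # u := W v / ||W v||_2
    return u, v


rng = np.random.default_rng(0)
W = rng.standard_normal((4, 6))  # stands in for the [H, W] matrix produced by Step 1
u0 = rng.standard_normal(4)
u0 = u0 / np.linalg.norm(u0)
u, v = power_iteration(W, u0, n_power_iterations=100)
# After enough rounds, u^T W v approaches the largest singular value of W.
print(u @ W @ v, np.linalg.norm(W, 2))
```

Each round pulls u and v toward the top left and right singular vectors of W, which is why u^T W v in Step 3 estimates the spectral norm.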

Step 3: Calculate \(\sigma(\mathbf{W})\) and normalize the weight.

\[ \begin{aligned}\sigma(\mathbf{W}) &= \mathbf{u}^{T} \mathbf{W} \mathbf{v} \\ \mathbf{W} &= \frac{\mathbf{W}}{\sigma(\mathbf{W})}\end{aligned} \]
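Putting the three steps together gives the NumPy sketch below. spectral_normalize is a hypothetical helper and the random-normal starting vector is an assumption; it illustrates the calculation, not how Paddle's hook is wired into a layer. The initial V from Step 1 is never read here, because Step 2 recomputes V from U before using it.

```python
import numpy as np


def spectral_normalize(weight, dim=0, n_power_iterations=1, eps=1e-12, seed=0):
    """Sketch of Steps 1-3: return (weight / sigma(W), sigma(W))."""
    if n_power_iterations < 1:
        raise ValueError("n_power_iterations should be a positive integer")
    # Step 1: view the weight as an [H, W] matrix and start from a random unit vector u
    # (random-normal initialization is an assumption of this sketch).
    mat = np.moveaxis(weight, dim, 0).reshape(weight.shape[dim], -1)
    u = np.random.default_rng(seed).standard_normal(mat.shape[0])
    u = u / max(np.linalg.norm(u), eps)
    # Step 2: power iteration; v is recomputed from u before it is ever read.
    for _ in range(n_power_iterations):
        v = mat.T @ u
        v = v / max(np.linalg.norm(v), eps)
        u = mat @ v
        u = u / max(np.linalg.norm(u), eps)
    # Step 3: sigma(W) = u^T W v, then divide the original (unreshaped) weight by it.
    sigma = float(u @ mat @ v)
    return weight / sigma, sigma


weight = np.random.default_rng(1).standard_normal((4, 3, 3, 3))
normalized, sigma = spectral_normalize(weight, dim=0, n_power_iterations=200)
# Once power iteration has converged, the normalized matrix has spectral norm close to 1.
print(sigma, np.linalg.norm(normalized.reshape(4, -1), 2))
```

Dividing the unreshaped weight by the scalar sigma keeps its original shape, so the result can replace the weight directly.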

Refer to Spectral Normalization.

Parameters
  • layer (Layer) – The Paddle layer that holds the weight to normalize.

  • name (str, optional) – Name of the weight parameter. Default: ‘weight’.

  • n_power_iterations (int, optional) – The number of power iterations used to estimate the spectral norm. Default: 1.

  • eps (float, optional) – The epsilon for numerical stability in calculating norms. Default: 1e-12.

  • dim (int, optional) – The index of the dimension to permute to the front before the weight is reshaped into a matrix. It should be set to 0 if the weight belongs to a fully connected (fc) layer, and to 1 if it belongs to a convolution layer. Default: None.
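The NumPy sketch below shows what n_power_iterations trades off, using a matrix whose spectral norm is exactly 3. estimate_sigma is a hypothetical helper, and starting from the fixed vector u0 is an assumption made so the numbers are reproducible; the eps guard mirrors the role described for eps above.

```python
import numpy as np


def estimate_sigma(W, u, n_power_iterations, eps=1e-12):
    """Estimate sigma(W) = u^T W v after n_power_iterations rounds of Step 2."""
    for _ in range(n_power_iterations):
        v = W.T @ u
        v = v / max(np.linalg.norm(v), eps)  # eps only matters if W^T u is (near) zero
        u = W @ v
        u = u / max(np.linalg.norm(u), eps)
    return float(u @ W @ v)


# The singular values of this W are exactly 3 and 1, so sigma(W) = 3.
W = np.array([[3.0, 0.0, 0.0],
              [0.0, 1.0, 0.0]])
u0 = np.array([1.0, 1.0]) / np.sqrt(2.0)
for n in (1, 2, 5):
    print(n, round(estimate_sigma(W, u0, n), 4))
# 1 2.8636
# 2 2.9982
# 5 3.0
```

One round leaves a visible gap here and a few more close it; how quickly depends on the ratio of the two largest singular values (1/3 in this example).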

Returns

Layer, the original layer with the spectral norm hook.

Examples

from paddle.nn import Conv2D
from paddle.nn.utils import spectral_norm

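conv = Conv2D(3, 1, 3)          # in_channels=3, out_channels=1, kernel_size=3
sn_conv = spectral_norm(conv)   # returns the same layer with the spectral norm hook
print(sn_conv)
# Conv2D(3, 1, kernel_size=[3, 3], data_format=NCHW)
print(sn_conv.weight)           # values and place depend on random initialization and device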
# Tensor(shape=[1, 3, 3, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
#        [[[[-0.21090528,  0.18563725, -0.14127982],
#           [-0.02310637,  0.03197737,  0.34353802],
#           [-0.17117859,  0.33152047, -0.28408015]],
#
#          [[-0.13336606, -0.01862637,  0.06959272],
#           [-0.02236020, -0.27091628, -0.24532901],
#           [ 0.27254242,  0.15516677,  0.09036587]],
#
#          [[ 0.30169338, -0.28146112, -0.11768346],
#           [-0.45765871, -0.12504843, -0.17482486],
#           [-0.36866254, -0.19969313,  0.08783543]]]])