spectral_norm
paddle.nn.utils.spectral_norm(layer, name='weight', n_power_iterations=1, eps=1e-12, dim=None)
Applies spectral normalization to a parameter according to the following calculation. Here \(\mathbf{W}\) denotes the weight reshaped into an H × W matrix.

Step 1: Generate a vector U of shape [H] and a vector V of shape [W], where H is the size of the dim-th dimension of the input weight and W is the product of the remaining dimensions.

Step 2: n_power_iterations must be a positive integer. Perform the following calculations with U and V for n_power_iterations rounds:

\[
\begin{aligned}
\mathbf{v} &:= \frac{\mathbf{W}^{T} \mathbf{u}}{\|\mathbf{W}^{T} \mathbf{u}\|_2}\\
\mathbf{u} &:= \frac{\mathbf{W} \mathbf{v}}{\|\mathbf{W} \mathbf{v}\|_2}
\end{aligned}
\]

Step 3: Calculate \(\sigma(\mathbf{W})\) and normalize the weight values:

\[
\begin{aligned}
\sigma(\mathbf{W}) &= \mathbf{u}^{T} \mathbf{W} \mathbf{v}\\
\mathbf{W} &= \frac{\mathbf{W}}{\sigma(\mathbf{W})}
\end{aligned}
\]

Refer to Spectral Normalization.
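Steps 2 and 3 can be sketched in plain Python on a small matrix. This is a minimal illustration, not Paddle's implementation: the helper names are hypothetical, the weight is assumed to be already reshaped into an H × W matrix (a list of rows), and eps is assumed to be added to each norm before dividing.

```python
import math


def l2_normalize(vec, eps):
    # Scale vec to unit L2 length; eps (added to the norm here, an assumption)
    # keeps the division finite when the norm is zero.
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / (norm + eps) for x in vec]


def matvec(mat, vec):
    # W v for an H x W matrix stored as a list of H rows.
    return [sum(a * b for a, b in zip(row, vec)) for row in mat]


def matvec_t(mat, vec):
    # W^T u: dot each column of W with u.
    return [sum(row[j] * vec[i] for i, row in enumerate(mat)) for j in range(len(mat[0]))]


def spectral_normalize(mat, u, n_power_iterations=1, eps=1e-12):
    """Steps 2 and 3: return (W / sigma, sigma, updated u)."""
    if n_power_iterations < 1:
        raise ValueError("n_power_iterations must be a positive integer")
    for _ in range(n_power_iterations):
        v = l2_normalize(matvec_t(mat, u), eps)  # v := W^T u / ||W^T u||_2
        u = l2_normalize(matvec(mat, v), eps)    # u := W v / ||W v||_2
    sigma = sum(ui * x for ui, x in zip(u, matvec(mat, v)))  # sigma(W) = u^T W v
    normalized = [[x / sigma for x in row] for row in mat]   # W := W / sigma(W)
    return normalized, sigma, u


weight_matrix = [[3.0, 0.0], [0.0, 1.0]]  # largest singular value is 3
u0 = [1.0, 1.0]
_, sigma_1, _ = spectral_normalize(weight_matrix, u0, n_power_iterations=1)
w_sn, sigma_20, _ = spectral_normalize(weight_matrix, u0, n_power_iterations=20)
print(round(sigma_1, 4), round(sigma_20, 4))  # 2.8636 3.0
```

With a single round, as with the default n_power_iterations=1, \(\sigma(\mathbf{W})\) is only an estimate (about 2.8636 here, against a true largest singular value of 3); more rounds bring it closer, and the normalized matrix then has a largest singular value of about 1.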
Parameters
layer (Layer) – A Paddle layer that has a weight parameter.
name (str, optional) – Name of the weight parameter. Default: 'weight'.
n_power_iterations (int, optional) – The number of power iterations used to estimate the spectral norm. Default: 1.
eps (float, optional) – A small value added for numerical stability when computing norms. Default: 1e-12.
dim (int, optional) – The index of the dimension that is permuted to the front before Input(Weight) is reshaped into a matrix. Set it to 0 if Input(Weight) is the weight of an fc layer, and to 1 if it is the weight of a conv layer. Default: None.
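The dim parameter fixes H and W in Step 1. The hypothetical helper below (not a Paddle API) computes the permutation and the resulting matrix shape, assuming the remaining dimensions keep their original order. It uses the weight shape [1, 3, 3, 3] of Conv2D(3, 1, 3), as printed in the example on this page.

```python
from math import prod


def matrix_shape(weight_shape, dim):
    """Return (perm, (H, W)) for the matrix view of a weight tensor.

    perm moves `dim` to the front (other dims keep their order, an assumption);
    H is weight_shape[dim] and W is the product of the remaining dimensions.
    """
    if not 0 <= dim < len(weight_shape):
        raise ValueError("dim out of range")
    perm = [dim] + [i for i in range(len(weight_shape)) if i != dim]
    h = weight_shape[dim]
    w = prod(weight_shape[i] for i in perm[1:])
    return perm, (h, w)


# Conv2D(3, 1, 3) weight: [out_channels, in_channels, kH, kW] = [1, 3, 3, 3]
conv_weight_shape = [1, 3, 3, 3]
print(matrix_shape(conv_weight_shape, 0))  # ([0, 1, 2, 3], (1, 27))
print(matrix_shape(conv_weight_shape, 1))  # ([1, 0, 2, 3], (3, 9))
```

With dim=0 the conv weight is viewed as a 1 × 27 matrix, so U has shape [1] and V has shape [27]; with dim=1 it becomes 3 × 9.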
Returns
Layer, the original layer with the spectral norm hook.
Examples
```python
from paddle.nn import Conv2D
from paddle.nn.utils import spectral_norm

conv = Conv2D(3, 1, 3)
sn_conv = spectral_norm(conv)  # returns the same layer with the spectral norm hook
print(sn_conv)
# Conv2D(3, 1, kernel_size=[3, 3], data_format=NCHW)

# The printed values depend on the random weight initialization.
print(sn_conv.weight)
# Tensor(shape=[1, 3, 3, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
#        [[[[-0.21090528,  0.18563725, -0.14127982],
#           [-0.02310637,  0.03197737,  0.34353802],
#           [-0.17117859,  0.33152047, -0.28408015]],
#
#          [[-0.13336606, -0.01862637,  0.06959272],
#           [-0.02236020, -0.27091628, -0.24532901],
#           [ 0.27254242,  0.15516677,  0.09036587]],
#
#          [[ 0.30169338, -0.28146112, -0.11768346],
#           [-0.45765871, -0.12504843, -0.17482486],
#           [-0.36866254, -0.19969313,  0.08783543]]]])
```