Linear¶
- class paddle.nn. Linear ( in_features, out_features, weight_attr=None, bias_attr=None, name=None ) [源代码] ¶
线性变换层。对于每个输入 Tensor \(X\),计算公式为:
其中,\(W\) 和 \(b\) 分别为权重和偏置。
Linear 层只接受一个 Tensor 作为输入,形状为 \([batch\_size, *, in\_features]\),其中 \(*\) 表示可以为任意个额外的维度。 该层可以计算输入 Tensor 与权重矩阵 \(W\) 的乘积,然后生成形状为 \([batch\_size, *, out\_features]\) 的输出 Tensor。 如果 \(bias\_attr\) 不是 False,则将创建一个偏置参数并将其添加到输出中。
参数¶
in_features (int) – 线性变换层输入单元的数目。
out_features (int) – 线性变换层输出单元的数目。
weight_attr (ParamAttr,可选) – 指定权重参数的属性。默认值为 None,表示使用默认的权重参数属性。如果 \(ParamAttr\) 的初始值未设置,则使用 \(Xavier\) 初始化参数,具体用法请参见 ParamAttr 。
bias_attr (ParamAttr|bool,可选) – 指定偏置参数的属性。\(bias\_attr\) 为 bool 类型且设置为 False 时,表示不会为该层添加偏置。\(bias\_attr\) 如果设置为 True 或者 None,则表示使用默认的偏置参数属性,将偏置参数初始化为 0。具体用法请参见 ParamAttr。默认值为 None。
name (str,可选) - 具体用法请参见 Name,一般无需设置,默认值为 None。
形状¶
输入:形状为 \([batch\_size, *, in\_features]\) 的多维 Tensor。其数据类型为 float16, float32, float64, 默认为 float32。
输出:形状为 \([batch\_size, *, out\_features]\) 的多维 Tensor。其数据类型与输入相同。
代码示例¶
>>> import paddle
>>> paddle.seed(100)
>>> # Define the linear layer.
>>> weight_attr = paddle.ParamAttr(
... name="weight",
... initializer=paddle.nn.initializer.Constant(value=0.5))
>>> bias_attr = paddle.ParamAttr(
... name="bias",
... initializer=paddle.nn.initializer.Constant(value=1.0))
>>> linear = paddle.nn.Linear(2, 4, weight_attr=weight_attr, bias_attr=bias_attr)
>>> print(linear.weight)
Parameter containing:
Tensor(shape=[2, 4], dtype=float32, place=Place(cpu), stop_gradient=False,
[[0.50000000, 0.50000000, 0.50000000, 0.50000000],
[0.50000000, 0.50000000, 0.50000000, 0.50000000]])
>>> print(linear.bias)
Parameter containing:
Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=False,
[1., 1., 1., 1.])
>>> x = paddle.randn((3, 2), dtype="float32")
>>> y = linear(x)
>>> print(y)
Tensor(shape=[3, 4], dtype=float32, place=Place(cpu), stop_gradient=False,
[[ 0.42121571, 0.42121571, 0.42121571, 0.42121571],
[ 0.85327661, 0.85327661, 0.85327661, 0.85327661],
[-0.05398512, -0.05398512, -0.05398512, -0.05398512]])