Linear#
- class torch.nn.modules.linear.Linear(in_features, out_features, bias=True, device=None, dtype=None)[源代码]#
对输入数据应用仿射线性变换: .
此模块支持 TensorFloat32。
在某些 ROCm 设备上,当使用 float16 输入时,此模块将对反向传播使用不同精度。
- 参数
in_features (int) – 每个输入样本的大小
out_features (int) – 每个输出样本的大小
bias (bool) – 如果设置为
False
,则该层将不学习加性偏置。默认值:True
- 形状
输入: ,其中 表示任意数量的维度(包括零个),并且 。
输出: ,其中除了最后一个维度外,所有维度都与输入形状相同,并且 。
- 变量
weight (torch.Tensor) – 模块的可学习权重,形状为 。其值从 初始化,其中 。
bias – 模块的可学习偏差,形状为 。如果
bias
为True
,则值从 初始化,其中 。
示例
>>> m = nn.Linear(20, 30) >>> input = torch.randn(128, 20) >>> output = m(input) >>> print(output.size()) torch.Size([128, 30])