CosineSimilarity#
- class torch.nn.CosineSimilarity(dim=1, eps=1e-08)[源代码]#
沿 dim 计算 和 之间的余弦相似度。
- 形状
输入1: ,其中 D 是 dim 位置的维度
输入2: ,与 x1 维度相同,在 dim 处的尺寸与 x1 匹配,并在其他维度上与 x1 可广播。
输出:
示例
>>> input1 = torch.randn(100, 128) >>> input2 = torch.randn(100, 128) >>> cos = nn.CosineSimilarity(dim=1, eps=1e-6) >>> output = cos(input1, input2)