当前位置:网站首页>技术干货 | 基于 MindSpore 实现 CosineSimilarity
技术干货 | 基于 MindSpore 实现 CosineSimilarity
2022-07-20 02:58:00 【昇思MindSpore】
Embedding Similarity 介绍
原理介绍及公式
Embedding Similarity,顾名思义就是通过嵌入向量来计算相似度,这个评价指标在网上的资料比较少,我今天来总结一哈。
相似度度量(Similarity),即计算个体间的相似程度,相似度度量的值越小,说明个体间相似度越小,相似度的值越大说明个体差异越大。
对于多个不同的文本或者短文本对话消息要来计算他们之间的相似度如何,一个好的做法就是将这些文本中词语,映射到向量空间,形成文本中文字和向量数据的映射关系,通过计算几个或者多个不同的向量的差异的大小,来计算文本的相似度。
MindSpore 代码实现
好了,原理已经讲完,话不多说,我们开始上代码。使用的是MindSpore框架实现的代码。
"""CosineSimilarity."""
import numpy as np
from mindspore._checkparam import Validator as validator
from .metric import Metric
class CosineSimilarity(Metric):
def __init__(self, similarity='cosine', reduction='none', zero_diagonal=True):
super().__init__()
similarity_list = ['dot', 'cosine']
reduction_list = ['none', 'sum', 'mean']
similarity = validator.check_value_type("similarity", similarity, [str])
# 度量方式有两种,dot和cosine
self.similarity = validator.check_string(similarity, similarity_list, "similarity")
# reduction有三种,none', 'sum', 'mean'
reduction = validator.check_value_type("reduction", reduction, [str])
self.reduction = validator.check_string(reduction, reduction_list, "reduction")
self.zero_diagonal = validator.check_value_type("zero_diagonal", zero_diagonal, [bool])
self.clear()
def clear(self):
"""清除历史数据"""
self.sqr_mtx_res = 0
self._is_update = False
def update(self, *inputs):
"""
更新输入数据,输入为1个
"""
# 输入必须是一个tensor,numpy或者list
input_data = self._convert_data(inputs[0])
# 选择使用的度量方式
if self.similarity == 'cosine':
data = np.linalg.norm(input_data, ord=2, axis=1)
input_data = input_data / np.expand_dims(data, 1)
self.sqr_mtx_res = np.dot(input_data, input_data.transpose(1, 0))
self._is_update = True
def eval(self):
"""
计算cosine similarity,返回的是一个矩阵
"""
if not self._is_update:
raise RuntimeError('Call the update method before calling eval.')
if self.zero_diagonal:
np.fill_diagonal(self.sqr_mtx_res, 0)
if self.reduction == 'mean':
self.sqr_mtx_res = np.mean(self.sqr_mtx_res, axis=-1)
if self.reduction == 'sum':
self.sqr_mtx_res = np.sum(self.sqr_mtx_res, axis=-1)
return self.sqr_mtx_res
使用方法如下:
import numpy as np
from mindspore.nn.metrics import CosineSimilarity
test_data = np.array([[5, 8, 3, 2], [5, 8, 3, 2], [4, 2, 3, 4]])
metric = CosineSimilarity()
metric.clear()
metric.update(test_data)
square_matrix = metric.eval()
print(square_matrix)
np.array([[0, 1, 0.78229315], [1, 0, 0.78229315], [0.78229315, 0.78229315, 0]])
这里说明一下,通常计算相似度都是有两个输入,比如A和B,通过Embedding similarity来计算两个输入的相似度,结果是一个数值。我们这里使用的是将A、B、C放到了一个矩阵里,比如 np.array([[5,8,3,2],[5,8,3,2],[4,2,3,4]]),
A为embedding之后的[5,8,3,2],
B为embedding之后的[5,8,3,2],
C为embedding之后的[4,2,3,4]。
得到的结果是np.array([[0, 1, 0.78229315], [1, 0, 0.78229315], [0.78229315, 0.78229315, 0]]),结果可表示为如下表格:
表格说明(上面每一个格对应着下面对应位置的格):
这样看结果就很清晰明了了吧,也是解释为什么输入数量为1了。不同于网上其他资料的直接两个输入x和y,拿x和y直接比的出来一个结果。这样写的原因是对于大量数据来说,更加方便,一般结果的取值范围是[-1,1]。
MindSpore官方资料
官方QQ群 : 486831414
官网:https://www.mindspore.cn/
Gitee : https : //gitee.com/mindspore/mindspore
GitHub : https://github.com/mindspore-ai/mindspore
论坛:https://bbs.huaweicloud.com/forum/forum-1076-1.html
边栏推荐
- Unity shader 实现图片带圆角和边线border
- [trivia] about some unity editors, they lack the tiles option when creating tile maps
- After reading this article, you should thoroughly understand how to do interface testing
- JVM tuning method
- ICMP - echo / echo reply (Ping) message
- At32 uses the kernel DWT register to set the delay time
- matlab-微分方程求解方法汇总
- LeetCode. 302 weekly games___ 01_ 6120. How many pairs can an array form__ Simple hash
- Technical analysis premint security events, how to avoid attacks?
- 95页智能工厂数字化、智能化规划、解决方案及建设方案2022
猜你喜欢
AT32 MCU F415 OTG新功能使用
如何在自动化测试中使用MitmProxy获取数据返回?
Win:使用 netsh 命令配置 Port Forwarding
Huawei employees revealed that this position is about to start recruiting a lot!!!
95 pages intelligent factory digitalization, intelligent planning, solutions and construction scheme 2022
LVS群集应用
Flink1.15源码阅读——flink-annotations
深度学习1-感知器
手动操纵工业机器人
私域流量和裂变营销的关系,什么是超级APP,我们企业能拥有吗?
随机推荐
Redis 主从复制&哨兵模式
[LeetCode]剑指 Offer 53 - I. 在排序数组中查找数字 I
【LeetCode-中等】34. 在排序数组中查找元素的第一个和最后一个位置 - 数组双指针
Top priority of dry goods: common indicators and terms in data analysis!
clion创建第一个C项目
20元一支的洗面奶,7天卖了上万,他们是如何做到的?
单体 or 微服务?你以为是架构权衡?其实是认知负载!
Calculate the value of any root n
最高的评价:您要走的开发事业道路做事的决心,行动是彻底的,诚恳的和绝对真实的
Compose中的FlowLayout
371 pages of 200000 words 2021 smart city informatization comprehensive construction plan
C -- string
如何在自动化测试中使用MitmProxy获取数据返回?
第九天(抓取流量、路由策略)
【obs】Transform: fit to screen
Record of force deduction and question brushing 2---35 Search insertion location
0.0.pytorch构建模型方法
NetFlow and SNMP are two different network monitoring methods
95页智能工厂数字化、智能化规划、解决方案及建设方案2022
Grouping convolution and deep separable convolution