当前位置:网站首页>【Mindspore-ascend】【自定义算子】GRAPH_MODE下,自定义如何遍历Tensor
【Mindspore-ascend】【自定义算子】GRAPH_MODE下,自定义如何遍历Tensor
2022-07-20 14:44:00 【小乐快乐】
因为需要使用了带weights的SoftmaxCrossEntropyLoss,就在mindspore提供的nn.SoftmaxCrossEntropyLoss上自定义了一个算子。
在该算子中,我需要遍历一维Tensor。在PYNATIVE模式中,我将Tensor转为numpy数组实现该操作。但是在GRAPH_MODE中无法将Tensor转为numpy或list,请问该如何操作?
【相关代码】
# labels_int: (n,)维Tensor
# self.ignore_label: int
# self.cls_weight: (c, )维Tensor
weights_np = np.ones((labels_int.shape[0]))
labels_np = labels_int.asnumpy()
weights_np[labels_np == self.ignore_label] = 0
cls_weight_np = self.cls_weight.asnumpy()
for idx, v in enumerate(cls_weight_np, 0):
weights_np[labels_np == idx] *= v
你的问题是想在图模式下,遍历一个一维Tensor,当前可以通过两种方式遍历Tensor。
第一种获取Tensor第一维的长度,然后通过整数索引遍历,如下:
for i in range(x.shape[0]): x[i]
第二种通过enumerate接口,直接遍历,如下:
for i, ele in enumerate(x, 0): ele
边栏推荐
猜你喜欢
【组成原理 五 系统总线】
【阿里云服务器】
微信小程序中使用vant weapp
[harmony OS] [FAQ] Hongmeng application development problem sharing (font / constructor)
QT connects to MySQL and operates the database (the clearest)
字体随窗体的变化而变化
【HMS core】【Wallet Kit】【解决方案】华为钱包的客户端示例代码为何无法运行
流量红利退去,快消品代理商如何借助RPA破局增长?
Nacos custom extended data ID configuration
LayoutInflater 布局渲染工具
随机推荐
Gap Locks(间隙锁)
思源能否内置一个密码管理器
如何用度量数据驱动代码评审的改善
常用的锂电池充电IC芯片
分布式系统中数据存储方案实践
Three paradigms of database design
一个换行符引发的思考!
【argoverse】argoverse-api 安装
Understanding of Zhongtai
js 数组reduce方法求和 求最大值 求最小值方法
C语言力扣第九题之回文数。两指针数组遍历法
MySQL pessimistic lock
DOS汇编分支、循环编程与寄存器分析
[composition principle V system bus]
Numerical conversion exercise and solution of microprocessor principle
How is the income calculated when the financial product expires?
MySQL隐式锁
VS stdio项目源文件中写多个main
WPF 实现 RichTextBox 关键字查询高亮
About Zend_ parse_ Parameters function