Deep learning (5): Data imbalance and class_weight
2022-07-21 19:36:00 【Pomelo flavored sheep】
When classifying images or other data, the number of samples per category may be unequal or even differ greatly; that is, the data is imbalanced. In this situation, the samples drawn into a training batch may come mostly, or even entirely, from the same (majority) category, so the model mostly learns the features of that category. The result is a low loss on the training set but poor accuracy on the validation set.
One remedy is to pass a class_weight to the loss function when computing the loss. The procedure is as follows:
import numpy as np
import torch
import torch.nn as nn
from sklearn.utils.class_weight import compute_class_weight

# Number of training samples in each of the 5 classes
classes = [190, 90, 121, 57, 28]
train_num = sum(classes)  # total number of training samples

# Build the label array: the first 190 samples are class 0, the next 90 are class 1, and so on
label = np.zeros(train_num, dtype=np.int64)
start = 0
for cls, count in enumerate(classes):
    label[start:start + count] = cls
    start += count

# Recent scikit-learn versions require classes and y as keyword arguments
class_weights = compute_class_weight('balanced', classes=np.unique(label), y=label)
# Data imbalance: use class_weight when computing the loss, giving each class its own weight
class_weights = torch.tensor(class_weights, dtype=torch.float)
loss_function = nn.CrossEntropyLoss(weight=class_weights)
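For reference, scikit-learn's 'balanced' mode gives each class the weight n_samples / (n_classes * count_c), so rarer classes receive larger weights. The sketch below recomputes this by hand for the same class counts, then checks on a small random batch that nn.CrossEntropyLoss with a weight (default reduction='mean') returns a weighted mean: the sum of w[y_i] * loss_i divided by the sum of w[y_i], not by the batch size. The helper balanced_weights and the random batch are illustrative assumptions, not part of the original post.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

def balanced_weights(counts):
    """Hypothetical helper: the 'balanced' formula n_samples / (n_classes * count_c)."""
    counts = np.asarray(counts, dtype=np.float64)
    return counts.sum() / (len(counts) * counts)

counts = [190, 90, 121, 57, 28]
w = balanced_weights(counts)
print(np.round(w, 4))  # the rarest class (28 samples) gets the largest weight

# Check how CrossEntropyLoss applies the weights with the default reduction='mean'
torch.manual_seed(0)
logits = torch.randn(8, 5)           # fake batch: 8 samples, 5 classes
targets = torch.randint(0, 5, (8,))  # fake integer labels
weight = torch.tensor(w, dtype=torch.float)

builtin = nn.CrossEntropyLoss(weight=weight)(logits, targets)

# Manual version: scale each sample's loss by the weight of its true class,
# then divide by the sum of those weights
per_sample = F.cross_entropy(logits, targets, reduction='none')  # unweighted -log p
sample_w = weight[targets]
manual = (sample_w * per_sample).sum() / sample_w.sum()

print(builtin.item(), manual.item())
```

Because the mean is normalized by the summed weights, multiplying all weights by the same constant leaves the loss value unchanged; only the relative weights between classes matter.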
After that, compute the loss with loss_function as usual. The code above runs without problems on the CPU, but on the GPU it raises the following error:
RuntimeError: Expected object of device type cuda but got device type cpu for argument #3 'weight' in call to _thnn_nll_loss_forward
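The cause: when running on the CPU, the device is cpu, so the weight tensor can be used as-is. When running on the GPU, however, the model outputs and labels are on cuda while class_weights is still a CPU tensor, so the class_weight passed to the loss must also be moved to cuda.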
Solution
Add .to(device) to the weight tensor, for example:
loss_function = nn.CrossEntropyLoss(class_weights.to(device))
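A minimal sketch of the fix, assuming a generic setup where device is chosen at runtime (falling back to the CPU when no GPU is available) and the model's logits and labels are already on device. Since nn.CrossEntropyLoss stores weight as a registered buffer, calling .to(device) on the loss module itself should move the weights too; both variants are shown. The weight values are the rounded 'balanced' weights for the counts above, and the tensor shapes are illustrative only.

```python
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Rounded 'balanced' weights for the class counts [190, 90, 121, 57, 28]
class_weights = torch.tensor([0.5116, 1.08, 0.8033, 1.7053, 3.4714])

# Option 1 (as in the post): move the weight tensor before building the loss
loss_function = nn.CrossEntropyLoss(class_weights.to(device))

# Option 2: build the loss on the CPU, then move the whole module;
# weight is a registered buffer, so it is moved along with the module
loss_function_2 = nn.CrossEntropyLoss(weight=class_weights).to(device)

# Stand-ins for model outputs and labels, both already on `device`
logits = torch.randn(4, 5, device=device)
labels = torch.tensor([0, 4, 2, 3], device=device)

loss = loss_function(logits, labels)
loss_2 = loss_function_2(logits, labels)
print(loss.item(), loss_2.item())  # same value, no device-mismatch error
```

Option 2 can be convenient when the model and the loss are moved to the device in the same place in the training script.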
886 (bye for now)! Make a habit of writing things down as you go ~