SMOTE算法进行过采样
算法思想
算法流程
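对每个少数类样本 $x_i$，从其 k 个最近邻中随机选取一个样本 $\tilde{x}_i$，按下式生成合成样本：
$$x_{new} = x_i + \mathrm{rand}(0,1) \times (\tilde{x}_i - x_i)$$
其中，rand(0,1) 是取值于 [0, 1] 区间的随机数，用于控制合成样本在 $x_i$ 与 $\tilde{x}_i$ 连线上的位置。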
对多个少数类样本重复以上公式的计算，即可生成相应数量的合成样本，从而平衡数据集。
类别不平衡问题
SMOTE算法缺陷
代码示例
import numpy as np
import random
import matplotlib.pyplot as plt
class SMOTE(object):
    """Generate synthetic minority-class samples with SMOTE.

    Each synthetic point lies on the segment between a randomly chosen
    sample and one of its k nearest neighbours (squared Euclidean distance).

    Parameters
    ----------
    sample : array-like, shape (n_samples, n_features)
        Minority-class samples to oversample.
    k : int, default 2
        Number of nearest neighbours to consider; clipped to n_samples - 1.
    gen_num : int, default 3
        Number of synthetic samples produced by each get_syn_data() call.
    """

    def __init__(self, sample, k=2, gen_num=3):
        # Accept lists/tuples too; an existing ndarray is used as-is (no copy).
        self.sample = np.asarray(sample)
        self.sample_num, self.feature_len = self.sample.shape
        # A point cannot have more neighbours than there are other points.
        self.k = min(k, self.sample_num - 1)
        self.gen_num = gen_num
        # Kept for backward compatibility; get_syn_data() rebinds it to a
        # freshly allocated array on every call.
        self.syn_data = np.zeros((self.gen_num, self.feature_len))
        # Row i holds the indices of sample i's k nearest neighbours, nearest first.
        self.k_neighbor = np.zeros((self.sample_num, self.k), dtype=int)

    def get_neighbor_point(self):
        """Fill self.k_neighbor with each sample's k nearest neighbours.

        The sample itself is always excluded, even when it has exact
        duplicates that would otherwise tie with it at distance 0.
        """
        for index, point in enumerate(self.sample):
            # Squared distances from this point to every sample, in one vectorized pass.
            distances = np.square(self.sample - point).sum(axis=1).astype(float)
            # Never pick the point as its own neighbour.
            distances[index] = np.inf
            # Stable sort: ties are broken by index, so the result is deterministic.
            order = np.argsort(distances, kind='stable')
            self.k_neighbor[index] = order[:self.k]

    def get_syn_data(self):
        """Generate and return gen_num synthetic samples.

        Returns
        -------
        numpy.ndarray, shape (gen_num, n_features)
            A new array on every call (also stored as self.syn_data), so
            results returned by earlier calls are never overwritten.

        Raises
        ------
        ValueError
            If samples are requested but no neighbour is available
            (fewer than two samples, or k < 1).
        """
        if self.gen_num > 0 and self.k < 1:
            raise ValueError(
                'SMOTE needs at least 2 samples and k >= 1 to interpolate, '
                'got %d sample(s) and k=%d' % (self.sample_num, self.k))
        self.get_neighbor_point()
        syn_data = np.zeros((self.gen_num, self.feature_len))
        for i in range(self.gen_num):
            # Pick a random base sample, then one of its k nearest neighbours.
            key = random.randint(0, self.sample_num - 1)
            neighbor = self.k_neighbor[key][random.randint(0, self.k - 1)]
            gap = self.sample[neighbor] - self.sample[key]
            # Interpolate: new = x + rand(0, 1) * (neighbour - x).
            syn_data[i] = self.sample[key] + random.uniform(0, 1) * gap
        self.syn_data = syn_data
        return syn_data
if __name__ == '__main__':
    # Demo: oversample 20 random 2-D points into 100 synthetic ones and plot both.
    data = np.random.uniform(0, 1, size=[20, 2])  # randomly generated original data
    synthetic_sample = SMOTE(data, 5, 100)
    new_data = synthetic_sample.get_syn_data()
    print('原数据', data)
    print('生成数据', new_data)
    # Plot original samples (blue dots) and synthetic samples (red triangles),
    # one vectorized scatter call per series instead of one call per point.
    plt.scatter(data[:, 0], data[:, 1], c='b', label='original')
    plt.scatter(new_data[:, 0], new_data[:, 1], c='r', marker='^', label='synthetic')
    plt.legend()
    plt.show()
参考
https://blog.csdn.net/chrnhao/article/details/124045702
https://www.cnblogs.com/june0507/p/11726492.html
作者:Gloriouszh