三招提升数据不平衡模型的性能(附python代码)

简介: 本文的主要目标是处理数据不平衡问题。文中描述了用来克服数据不平衡问题的三种技术,分别是集成交叉验证、类别权重以及过大预测 。

       对于深度学习而言,数据集非常重要,但在实际项目中,或多或少会碰见数据不平衡问题。什么是数据不平衡呢?举例来说,现在有一个任务是判断西瓜是否成熟,这是一个二分类问题——西瓜是生的还是熟的,该任务的数据集由两部分数据组成,成熟西瓜与生西瓜,假设生西瓜的样本数量远远大于成熟西瓜样本的数量,针对这样的数据集训练出来的算法“偏向”于识别新样本为生西瓜,存心让你买不到甜的西瓜以解夏天之苦,这就是一个数据不平衡问题。针对数据不平衡问题有相应的处理办法,比如对多数样本进行采样使得其样本数量级与少样本数相近,或者是对少数样本重复使用等。最近恰好在面试中遇到一个数据不平衡问题,这也是面试中经常会出现的问题之一,现向读者分享此次解决问题的心得。

1_jpeg

数据集

       训练数据中有三个标签,分别标记为[1、2、3],这意味着该问题是一个多分类问题。训练数据集有17个特征以及38829个独立数据点。而在测试数据中,有16个没有标签的特征和16641个数据点。该训练数据集非常不平衡,大部分数据是1类(95%),而2类和3类分别有3.0%和0.87%的数据,如下图所示。

2

算法

       经过初步观察,决定采用随机森林(RF)算法,因为它优于支持向量机、Xgboost以及LightGBM算法。在这个项目中选择RF还有几个原因:

  • 1机森林对过拟合具有很强的鲁棒性;
  • 2.参数化仍然非常直观;
  • 3.在这个项目中,有许多成功的用例将随机森林算法用于高度不平衡的数据集;
  • 4.个人有先前的算法实施经验;
           为了找到最佳参数,使用scikit-sklearn实现的GridSearchCV对指定的参数值执行网格搜索,更多细节可以在本人的Github上找到。

为了处理数据不平衡问题,使用了以下三种技术:

A.使用集成交叉验证(CV):

       在这个项目中,使用交叉验证来验证模型的鲁棒性。整个数据集被分成五个子集。在每个交叉验证中,使用其中的四个子集用于训练,剩余的子集用于验证模型,此外模型还对测试数据进行了预测。在交叉验证结束时,会得到五个测试预测概率。最后,对所有类别的概率取平均值。模型的训练表现稳定,每个交叉验证上具有稳定的召回率和f1分数。这项技术也帮助我在Kaggle比赛中取得了很好的成绩(前1%)。以下部分代码片段显示了集成交叉验证的实现:

for j, (train_idx, valid_idx) in enumerate(folds):
                
                X_train = X[train_idx]
                Y_train = y[train_idx]
                X_valid = X[valid_idx]
                Y_valid = y[valid_idx]
                
                clf.fit(X_train, Y_train)
                
                valid_pred = clf.predict(X_valid)
                recall  = recall_score(Y_valid, valid_pred, average='macro')
                f1 = f1_score(Y_valid, valid_pred, average='macro')
                
                recall_scores[i][j] = recall
                f1_scores[i][j] = f1
                
                train_pred[valid_idx, i] = valid_pred
                test_pred[:, test_col] = clf.predict(T)
                test_col += 1
                
                ## Probabilities
                valid_proba = clf.predict_proba(X_valid)
                train_proba[valid_idx, :] = valid_proba
                test_proba  += clf.predict_proba(T)
                
            test_proba /= self.n_splits

B.设置类别权重/重要性:

       代价敏感学习是使随机森林更适合从非常不平衡的数据中学习的方法之一。随机森林有倾向于偏向大多数类别。因此,对少数群体错误分类施加昂贵的惩罚可能是有作用的。由于这种技术可以改善模型性能,所以我给少数群体分配了很高的权重(即更高的错误分类成本)。然后将类别权重合并到随机森林算法中。我根据类别1中数据集的数量与其它数据集的数量之间的比率来确定类别权重。例如,类别1和类别3数据集的数目之间的比率约为110,而类别1和类别2的比例约为26。现在我稍微对数量进行修改以改善模型的性能,以下代码片段显示了不同类权重的实现:

from sklearn.ensemble import RandomForestClassifier
class_weight = dict({1:1.9, 2:35, 3:180})

rdf = RandomForestClassifier(bootstrap=True,
            class_weight=class_weight, 
            criterion='gini',
            max_depth=8, max_features='auto', max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=4, min_samples_split=10,
            min_weight_fraction_leaf=0.0, n_estimators=300,
            oob_score=False,
            random_state=random_state,
            verbose=0, warm_start=False)

C.过大预测标签而不是过小预测(Over-Predict a Label than Under-Predict):

       这项技术是可选的,通过实践发现,这种方法对提高少数类别的表现非常有效。简而言之,如果将模型错误分类为类别3,则该技术能最大限度地惩罚该模型,对于类别2和类别1惩罚力度稍差一些。 为了实施该方法,我改变了每个类别的概率阈值,将类别3、类别2和类别1的概率设置为递增顺序(即,P3= 0.25,P2= 0.35,P1= 0.50),以便模型被迫过度预测类别。该算法的详细实现可以在Github上找到。

最终结果

       以下结果表明,上述三种技术如何帮助改善模型性能:
1.使用集成交叉验证的结果:

3


2.使用集成交叉验证+类别权重的结果:

4


3.使用集成交叉验证+类别权重+过大预测标签的结果:

5

结论

       由于在实施过大预测技术方面的经验很少,因此最初的时候处理起来非常棘手。但是,研究该问题有助于提升我解决问题的能力。对于每个任务而言,起初可能确实是陌生的,这个时候不要害怕,一次次尝试就好。由于时间的限制(48小时),无法将精力分散于模型的微调以及特征工程,存在改进的地方还有很多,比如删除不必要的功能并添加一些额外功能。此外,也尝试过LightGBM和XgBoost算法,但在实践过程中发现,随机森林的效果优于这两个算法。在后面的研究中,可以进一步尝试一些其他算法,比如神经网络、稀疏编码等。

数十款阿里云产品限时折扣中,赶紧点击领劵开始云上实践吧!

作者信息

Sabber Ahamed,计算地球物理学、机器学习爱好者
个人主页:https://www.linkedin.com/in/sabber-ahamed/
本文由阿里云云栖社区组织翻译。
文章原标题《Three techniques to improve machine learning model performance with imbalanced datasets》,译者:海棠,审校:Uncle_LLD。
文章为简译,更为详细的内容,请查看原文

相关文章
|
4天前
|
机器学习/深度学习 数据采集 前端开发
【Python机器学习专栏】模型泛化能力与交叉验证
【4月更文挑战第30天】本文探讨了机器学习中模型泛化能力的重要性,它是衡量模型对未知数据预测能力的关键。过拟合和欠拟合影响泛化能力,而交叉验证是评估和提升泛化能力的有效工具。通过K折交叉验证等方法,可以发现并优化模型,如调整参数、选择合适模型、数据预处理、特征选择和集成学习。Python中可利用scikit-learn的cross_val_score函数进行交叉验证。
|
4天前
|
机器学习/深度学习 数据可视化 前端开发
【Python机器学习专栏】机器学习模型评估的实用方法
【4月更文挑战第30天】本文介绍了机器学习模型评估的关键方法,包括评估指标(如准确率、精确率、召回率、F1分数、MSE、RMSE、MAE及ROC曲线)和交叉验证技术(如K折交叉验证、留一交叉验证、自助法)。混淆矩阵提供了一种可视化分类模型性能的方式,而Python的scikit-learn库则方便实现这些评估。选择适合的指标和验证方法能有效优化模型性能。
|
4天前
|
机器学习/深度学习 算法 前端开发
【Python机器学习专栏】机器学习中的模型融合技术
【4月更文挑战第30天】模型融合,即集成学习,通过结合多个模型提升预测性能。常见方法包括:Bagging(如Random Forest)、Boosting(如AdaBoost、XGBoost)和Stacking。Python中可使用`scikit-learn`实现,例如BaggingClassifier示例。模型融合是机器学习中的强大工具,能提高整体性能并适应复杂问题。
|
4天前
|
机器学习/深度学习 Python
【Python 机器学习专栏】模型选择中的交叉验证与网格搜索
【4月更文挑战第30天】交叉验证和网格搜索是机器学习中优化模型的关键技术。交叉验证通过划分数据集进行多次评估,如K折和留一法,确保模型性能的稳定性。网格搜索遍历预定义参数组合,寻找最佳参数设置。两者结合能全面评估模型并避免过拟合。Python中可使用`sklearn`库实现这一过程,但需注意计算成本、过拟合风险及数据适应性。理解并熟练应用这些方法能提升模型性能和泛化能力。
|
4天前
|
机器学习/深度学习 数据可视化 TensorFlow
【Python 机器学习专栏】使用 TensorFlow 构建深度学习模型
【4月更文挑战第30天】本文介绍了如何使用 TensorFlow 构建深度学习模型。TensorFlow 是谷歌的开源深度学习框架,具备强大计算能力和灵活编程接口。构建模型涉及数据准备、模型定义、选择损失函数和优化器、训练、评估及模型保存部署。文中以全连接神经网络为例,展示了从数据预处理到模型训练和评估的完整流程。此外,还提到了 TensorFlow 的自动微分、模型可视化和分布式训练等高级特性。通过本文,读者可掌握 TensorFlow 基本用法,为构建高效深度学习模型打下基础。
|
4天前
|
算法 数据挖掘 Python
Python贝叶斯MCMC:Metropolis-Hastings、Gibbs抽样、分层模型、收敛性评估
Python贝叶斯MCMC:Metropolis-Hastings、Gibbs抽样、分层模型、收敛性评估
12 2
|
4天前
|
机器学习/深度学习 存储 数据采集
【Python 机器学习专栏】PCA(主成分分析)在数据降维中的应用
【4月更文挑战第30天】本文探讨了主成分分析(PCA)在高维数据降维中的应用。PCA通过线性变换找到最大化方差的主成分,从而降低数据维度,简化存储和计算,同时去除噪声。文章介绍了PCA的基本原理、步骤,强调了PCA在数据降维、可视化和特征提取上的优势,并提供了Python实现示例。PCA广泛应用在图像压缩、机器学习和数据分析等领域,但降维后可能损失解释性,需注意选择合适主成分数量及数据预处理。
|
4天前
|
机器学习/深度学习 算法 Python
【Python 机器学习专栏】随机森林算法的性能与调优
【4月更文挑战第30天】随机森林是一种集成学习方法,通过构建多棵决策树并投票或平均预测结果,具有高准确性、抗过拟合、处理高维数据的能力。关键性能因素包括树的数量、深度、特征选择和样本大小。调优方法包括调整树的数量、深度,选择关键特征和参数优化。Python 示例展示了使用 GridSearchCV 进行调优。随机森林广泛应用于分类、回归和特征选择问题,是机器学习中的重要工具。
|
4天前
|
vr&ar Python
Python自激励阈值自回归(SETAR)、ARMA、BDS检验、预测分析太阳黑子时间序列数据
Python自激励阈值自回归(SETAR)、ARMA、BDS检验、预测分析太阳黑子时间序列数据
11 0
|
4天前
|
机器学习/深度学习 算法 数据挖掘
【Python 机器学习专栏】Python 中的线性回归模型详解
【4月更文挑战第30天】本文介绍了Python中的线性回归模型,包括基本原理、实现步骤和应用。线性回归假设因变量与自变量间存在线性关系,通过建立数学模型进行预测。实现过程涉及数据准备、模型构建、参数估计、评估和预测。常用的Python库有Scikit-learn和Statsmodels。线性回归简单易懂,广泛应用,但对异常值敏感且假设线性关系。其扩展形式如多元线性、多项式回归和正则化方法能适应不同场景。理解并运用线性回归有助于数据分析和预测。