Python绘制混淆矩阵
混淆矩阵(Confusion Matrix)是用来评估分类模型性能的工具,特别是在处理二分类和多分类问题时。它以表格的形式展示了模型预测结果与真实标签之间的关系。
我首先在这里分享一下如何使用python绘制简单的混淆矩阵,后续可能会出通过python库对混淆矩阵进行计算的过程以及使用Origin绘制的过程。
代码如下:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# 假设这是你的混淆矩阵数据
confusion_matrix = np.array([
[680, 18, 133, 95, 3],
[22, 406, 25, 93, 11],
[99, 24, 348, 217, 7],
[45, 43, 83, 1149, 78],
[5, 7, 8, 127, 406]
])
# 自定义的标签名称
labels = ['Class A', 'Class B', 'Class C', 'Class D', 'Class E']
# 设置绘图
plt.figure(figsize=(10, 7))
ax = sns.heatmap(confusion_matrix, annot=True, fmt='d', cmap='Blues', cbar_kws={'label': 'Scale'},
xticklabels=labels, yticklabels=labels)
# 添加标题和轴标签
plt.title('Confusion Matrix')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
# 优化标签显示
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
# 显示图形
plt.show()
绘制的效果如下:
这种方法仅适用于其中数据已知的情况,方便对颜色等的调节,但python生成的图片可能不符合一些论文对图片分辨率的要求,所以后续我会继续分享其它绘制混淆矩阵的方法。
混淆矩阵为评估分类模型提供了一种直观的方式,使得我们可以详细分析模型的表现,尤其是在类别不平衡的情况下。通过理解混淆矩阵的内容和相关指标,我们可以更有效地优化和调整模型,提升其分类性能。
作者:T.i.s