如何在Matplotlib中绘制混淆矩阵
参考: How can I plot a confusion matrix in matplotlib
混淆矩阵是一种特别的矩阵,用于可视化机器学习算法的性能,特别是在分类问题中。它显示了实际类别与模型预测类别的对比。在本文中,我们将详细介绍如何使用Python的Matplotlib库来绘制混淆矩阵。我们将从基础知识开始,逐步深入到更复杂的可视化技巧。
基本混淆矩阵的绘制
首先,我们需要安装和导入必要的库。确保你已经安装了matplotlib
和numpy
。
import matplotlib.pyplot as plt
import numpy as np
接下来,我们创建一个简单的混淆矩阵,并使用matshow()
函数来绘制它。
import matplotlib.pyplot as plt
import numpy as np
# 示例代码 1
conf_matrix = np.array([[5, 2], [3, 5]])
plt.matshow(conf_matrix, cmap='Blues')
plt.title('Confusion Matrix - how2matplotlib.com')
plt.colorbar()
plt.show()
Output:
添加标签和注释
为了使混淆矩阵更易于理解,我们可以添加轴标签和每个单元格的注释。
import matplotlib.pyplot as plt
import numpy as np
# 示例代码 2
conf_matrix = np.array([[10, 1], [2, 10]])
fig, ax = plt.subplots()
cax = ax.matshow(conf_matrix, cmap='Oranges')
fig.colorbar(cax)
ax.set_xticklabels([''] + ['Positive', 'Negative'])
ax.set_yticklabels([''] + ['True', 'False'])
plt.title('Confusion Matrix with Labels - how2matplotlib.com')
for (i, j), val in np.ndenumerate(conf_matrix):
ax.text(j, i, f'{val}', ha='center', va='center', color='black')
plt.show()
使用不同的颜色映射
Matplotlib提供了多种颜色映射选项,可以帮助更好地区分不同的值。
import matplotlib.pyplot as plt
import numpy as np
# 示例代码 3
conf_matrix = np.array([[15, 5], [5, 15]])
plt.matshow(conf_matrix, cmap='coolwarm')
plt.title('Confusion Matrix with Color Map - how2matplotlib.com')
plt.colorbar()
plt.show()
Output:
调整颜色条的位置和大小
我们可以调整颜色条的位置和大小,以更好地适应图形的布局。
import matplotlib.pyplot as plt
import numpy as np
# 示例代码 4
conf_matrix = np.array([[20, 2], [2, 20]])
fig, ax = plt.subplots()
cax = ax.matshow(conf_matrix, cmap='Greens')
fig.colorbar(cax, fraction=0.046, pad=0.04)
plt.title('Confusion Matrix with Adjusted Colorbar - how2matplotlib.com')
plt.show()
Output:
增加网格线
为了更清晰地区分每个单元格,我们可以添加网格线。
import matplotlib.pyplot as plt
import numpy as np
# 示例代码 5
conf_matrix = np.array([[25, 5], [10, 25]])
plt.matshow(conf_matrix, cmap='Purples')
plt.title('Confusion Matrix with Grid - how2matplotlib.com')
plt.grid(True)
plt.colorbar()
plt.show()
Output:
使用不同的字体和字体大小
我们可以改变注释的字体和大小,使其更符合整体的视觉风格。
import matplotlib.pyplot as plt
import numpy as np
# 示例代码 6
conf_matrix = np.array([[30, 10], [10, 30]])
fig, ax = plt.subplots()
cax = ax.matshow(conf_matrix, cmap='Reds')
fig.colorbar(cax)
ax.set_xticklabels([''] + ['Positive', 'Negative'])
ax.set_yticklabels([''] + ['True', 'False'])
plt.title('Confusion Matrix with Custom Fonts - how2matplotlib.com')
for (i, j), val in np.ndenumerate(conf_matrix):
ax.text(j, i, f'{val}', ha='center', va='center', color='white', fontsize=14, fontweight='bold')
plt.show()
结论
在本文中,我们介绍了如何使用Matplotlib绘制混淆矩阵,并通过添加标签、注释、调整颜色和样式等方式增强了图形的可读性和美观性。