大家从python基础到如今的入门,想必都对python有一定基础,今天小编给大家带来一个关于python的高阶内容——绘制混淆矩阵,一起来看下吧~
介绍:
混淆矩阵通过表示正确/不正确标签的计数来表示模型在表格格式中的准确性。
计算/绘制混淆矩阵:
以下是计算混淆矩阵的过程。
您需要一个包含预期结果值的测试数据集或验证数据集。
对测试数据集中的每一行进行预测。
从预期的结果和预测计数:
每个类的正确预测数量。
每个类的错误预测数量,由预测的类组织。
然后将这些数字组织成表格或矩阵,如下所示:
Expected down the side:矩阵的每一行都对应一个预测的类。
Predicted across the top:矩阵的每一列对应于一个实际的类。
然后将正确和不正确分类的计数填入表格中。
Reading混淆矩阵:
一个类的正确预测的总数进入该类值的预期行,以及该类值的预测列。
以同样的方式,一个类别的不正确预测总数进入该类别值的预期行,以及该类别值的预测列。
对角元素表示预测标签等于真实标签的点的数量,而非对角线元素是分类器错误标记的元素。混淆矩阵的对角线值越高越好,表明许多正确的预测。
用Python绘制混淆矩阵 :
importitertools
importnumpyasnp
importmatplotlib.pyplotasplt
fromsklearnimportsvm,datasets
fromsklearn.model_selectionimporttrain_test_split
fromsklearn.metricsimportconfusion_matrix
#importsomedatatoplaywith
iris=datasets.load_iris()
X=iris.data
y=iris.target
class_names=iris.target_names
#Splitthedataintoatrainingsetandatestset
X_train,X_test,y_train,y_test=train_test_split(X,y,random_state=0)
#Runclassifier,usingamodelthatistooregularized(Ctoolow)tosee
#theimpactontheresults
classifier=svm.SVC(kernel='linear',C=0.01)
y_pred=classifier.fit(X_train,y_train).predict(X_test)
defplot_confusion_matrix(cm,classes,
normalize=False,
title='Confusionmatrix',
cmap=plt.cm.Blues):
"""
Thisfunctionprintsandplotstheconfusionmatrix.
Normalizationcanbeappliedbysetting`normalize=True`.
"""
ifnormalize:
cm=cm.astype('float')/cm.sum(axis=1)[:,np.newaxis]
print("Normalizedconfusionmatrix")
else:
print('Confusionmatrix,withoutnormalization')
print(cm)
plt.imshow(cm,interpolation='nearest',cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks=np.arange(len(classes))
plt.xticks(tick_marks,classes,rotation=45)
plt.yticks(tick_marks,classes)
fmt='.2f'ifnormalizeelse'd'
thresh=cm.max()/2.
fori,jinitertools.product(range(cm.shape[0]),range(cm.shape[1])):
plt.text(j,i,format(cm[i,j],fmt),
horizontalalignment="center",
color="white"ifcm[i,j]>threshelse"black")
color="white"ifcm[i,j]>threshelse"black")
plt.tight_layout()
plt.ylabel('Truelabel')
plt.xlabel('Predictedlabel')
#Computeconfusionmatrix
cnf_matrix=confusion_matrix(y_test,y_pred)
np.set_printoptions(precision=2)
#Plotnon-normalizedconfusionmatrix
plt.figure()
plot_confusion_matrix(cnf_matrix,classes=class_names,
title='Confusionmatrix,withoutnormalization')
#Plotnormalizedconfusionmatrix
plt.figure()
plot_confusion_matrix(cnf_matrix,classes=class_names,normalize=True,
title='Normalizedconfusionmatrix')
plt.show()
Confusionmatrix,withoutnormalization [[1300] [0106] [009]] Normalizedconfusionmatrix [[1.0.0.] [0.0.620.38] [0.0.1.]]
好了,大家可以消化学习下哦~如需了解更多python实用知识,点击进入PyThon学习网教学中心。
上一篇
下一篇