
Plot confusion matrix


import itertools
import numpy as np
import matplotlib.pyplot as plt

from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

# import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target
class_names = iris.target_names

# Split the data into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Run classifier, using a model that is too regularized (C too low) to see
# the impact on the results
classifier = svm.SVC(kernel='linear', C=0.01)
y_pred = classifier.fit(X_train, y_train).predict(X_test)


def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    # Normalize before plotting so the heatmap, the printed matrix and
    # the cell annotations all show the same values.
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Annotate each cell, in white on dark cells and black on light ones
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

# Compute confusion matrix
cnf_matrix = confusion_matrix(y_test, y_pred)
np.set_printoptions(precision=2)

# Plot non-normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names,
                      title='Confusion matrix, without normalization')

# Plot normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
                      title='Normalized confusion matrix')

plt.show()
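On recent versions of scikit-learn (0.22 and later), most of the helper above is built in: sklearn.metrics.ConfusionMatrixDisplay draws the same kind of annotated heatmap. A minimal sketch, reusing cnf_matrix and class_names from above:

from sklearn.metrics import ConfusionMatrixDisplay

disp = ConfusionMatrixDisplay(confusion_matrix=cnf_matrix,
                              display_labels=class_names)
disp.plot(cmap=plt.cm.Blues)
plt.show()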

As hinted in this question, you have to "open" the lower-level artist API by storing the figure and axes objects returned by the matplotlib functions you call (the fig, ax and cax variables below). You can then replace the default x- and y-axis ticks using set_xticklabels/set_yticklabels:

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

# y_test and pred come from your own train/test run
labels = ['business', 'health']
cm = confusion_matrix(y_test, pred, labels=labels)
print(cm)
fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(cm)
plt.title('Confusion matrix of the classifier')
fig.colorbar(cax)
# Set the tick positions explicitly before labelling them
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels)
ax.set_yticks(range(len(labels)))
ax.set_yticklabels(labels)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()

Note that I passed the labels list to the confusion_matrix function to make sure it's properly sorted, matching the ticks.

This results in the following figure:

[Figure: confusion matrix of the classifier, with 'business' and 'health' tick labels]
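To see why the label order matters: without the labels argument, scikit-learn orders the rows and columns by sorted label values, which may not match the order of your tick labels. A small illustration with made-up lists:

from sklearn.metrics import confusion_matrix

y_true = ['health', 'business', 'health', 'business']
y_pred = ['health', 'business', 'business', 'business']

# Default ordering: sorted labels, i.e. ['business', 'health']
print(confusion_matrix(y_true, y_pred))

# Explicit ordering, matching whatever tick labels you draw
print(confusion_matrix(y_true, y_pred, labels=['health', 'business']))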

----------------------------------------------------------------------------------------------------------------------------

You might be interested in https://github.com/pandas-ml/pandas-ml/, which provides a pandas-based implementation of the confusion matrix.

Some features:

  • plot confusion matrix
  • plot normalized confusion matrix
  • class statistics
  • overall statistics

Here is an example:

In [1]: from pandas_ml import ConfusionMatrix
In [2]: import matplotlib.pyplot as plt

In [3]: y_test = ['business', 'business', 'business', 'business', 'business',
        'business', 'business', 'business', 'business', 'business',
        'business', 'business', 'business', 'business', 'business',
        'business', 'business', 'business', 'business', 'business']

In [4]: y_pred = ['health', 'business', 'business', 'business', 'business',
       'business', 'health', 'health', 'business', 'business', 'business',
       'business', 'business', 'business', 'business', 'business',
       'health', 'health', 'business', 'health']

In [5]: cm = ConfusionMatrix(y_test, y_pred)

In [6]: cm
Out[6]:
Predicted  business  health  __all__
Actual
business         14       6       20
health            0       0        0
__all__          14       6       20

In [7]: cm.plot()
Out[7]: <matplotlib.axes._subplots.AxesSubplot at 0x1093cf9b0>

In [8]: plt.show()

[Figure: pandas_ml confusion matrix plot]

In [9]: cm.print_stats()
Confusion Matrix:

Predicted  business  health  __all__
Actual
business         14       6       20
health            0       0        0
__all__          14       6       20


Overall Statistics:

Accuracy: 0.7
95% CI: (0.45721081772371086, 0.88106840959427235)
No Information Rate: ToDo
P-Value [Acc > NIR]: 0.608009812201
Kappa: 0.0
Mcnemar's Test P-Value: ToDo


Class Statistics:

Classes                                 business health
Population                                    20     20
P: Condition positive                         20      0
N: Condition negative                          0     20
Test outcome positive                         14      6
Test outcome negative                          6     14
TP: True Positive                             14      0
TN: True Negative                              0     14
FP: False Positive                             0      6
FN: False Negative                             6      0
TPR: (Sensitivity, hit rate, recall)         0.7    NaN
TNR=SPC: (Specificity)                       NaN    0.7
PPV: Pos Pred Value (Precision)                1      0
NPV: Neg Pred Value                            0      1
FPR: False-out                               NaN    0.3
FDR: False Discovery Rate                      0      1
FNR: Miss Rate                               0.3    NaN
ACC: Accuracy                                0.7    0.7
F1 score                               0.8235294      0
MCC: Matthews correlation coefficient        NaN    NaN
Informedness                                 NaN    NaN
Markedness                                     0      0
Prevalence                                     1      0
LR+: Positive likelihood ratio               NaN    NaN
LR-: Negative likelihood ratio               NaN    NaN
DOR: Diagnostic odds ratio                   NaN    NaN
FOR: False omission rate                       1      0
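
If you only need the per-class precision, recall and F1 numbers without an extra dependency, scikit-learn's classification_report prints a similar (smaller) summary. A minimal sketch, reusing the y_test and y_pred lists from the session above (zero_division=0, available on recent scikit-learn, suppresses warnings for the empty 'health' row):

from sklearn.metrics import classification_report

print(classification_report(y_test, y_pred, zero_division=0))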
------------------------------------------------------------------------------------------------------------------------------------------------

Here is some discussion from a Coursera forum thread about confusion matrices and multi-class precision/recall measurement.

The basic idea is to compute the precision and recall for every class, then average them to get a single real-number measurement.

A confusion matrix makes it easy to compute the precision and recall of a class.

Below is a basic explanation of the confusion matrix, copied from that thread:

A confusion matrix is a way of tabulating true positives, true negatives, false positives, and false negatives when there are more than 2 classes. It is used for computing precision and recall, and hence the F1 score, for multi-class problems.

The actual values are represented by columns. The predicted values are represented by rows.

Examples:

10 training examples that are actually 8 are incorrectly classified (predicted) as 5: this is the entry cm(5, 8) = 10
13 training examples that are actually 4 are incorrectly classified as 9: this is the entry cm(9, 4) = 13

Confusion Matrix

cm =
          1    2    3    4    5    6    7    8    9   10
     1  298    2    1    0    1    1    3    1    1    0
     2    0  293    7    4    1    0    5    2    0    0
     3    1    3  263    0    8    0    0    3    0    2
     4    1    5    0  261    4    0    3    2    0    1
     5    0    0   10    0  254    3    0   10    2    1
     6    0    4    1    1    4  300    0    1    0    0
     7    1    3    2    0    0    0  264    0    7    1
     8    3    5    3    1    7    1    0  289    1    0
     9    0    1    3   13    1    0   11    1  289    0
    10    0    6    0    1    6    1    2    1    4  304

(rows and columns are labelled with the class values 1 to 10)

For class x, given the convention above (rows = predicted, columns = actual):

  • True positive: the diagonal entry, cm(x, x).

  • False positive: sum of row x (without the main diagonal), sum(cm(x, :), 2) - cm(x, x); these are examples predicted as x that actually belong to another class.

  • False negative: sum of column x (without the main diagonal), sum(cm(:, x)) - cm(x, x); these are examples that are actually x but were predicted as another class.

You can compute precision, recall and the F1 score for each class from these counts, following the course formulas.

Averaging over all classes (with or without weighting) gives values for the entire model, as sketched below.
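
A minimal numpy sketch of those computations, using the matrix above and assuming the same convention (rows = predicted, columns = actual); the variable names are just for illustration:

import numpy as np

# Rows = predicted class, columns = actual class, as in the matrix above
cm = np.array([
    [298,   2,   1,   0,   1,   1,   3,   1,   1,   0],
    [  0, 293,   7,   4,   1,   0,   5,   2,   0,   0],
    [  1,   3, 263,   0,   8,   0,   0,   3,   0,   2],
    [  1,   5,   0, 261,   4,   0,   3,   2,   0,   1],
    [  0,   0,  10,   0, 254,   3,   0,  10,   2,   1],
    [  0,   4,   1,   1,   4, 300,   0,   1,   0,   0],
    [  1,   3,   2,   0,   0,   0, 264,   0,   7,   1],
    [  3,   5,   3,   1,   7,   1,   0, 289,   1,   0],
    [  0,   1,   3,  13,   1,   0,  11,   1, 289,   0],
    [  0,   6,   0,   1,   6,   1,   2,   1,   4, 304],
])

tp = np.diag(cm)            # true positives per class
fp = cm.sum(axis=1) - tp    # predicted as x but actually another class
fn = cm.sum(axis=0) - tp    # actually x but predicted as another class

precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)

# Macro average: unweighted mean over all classes
print('macro precision:', precision.mean())
print('macro recall:   ', recall.mean())
print('macro F1:       ', f1.mean())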

---------------------------------------------------------------------------------------------------------------------------------------------

