Confusion_Matrix.py
This commit is contained in:
47
tools/sklearn/Confusion_Matrix.py
Normal file
47
tools/sklearn/Confusion_Matrix.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
# Authors: The scikit-learn developers
|
||||||
|
# SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
# https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from sklearn import datasets, svm
|
||||||
|
from sklearn.metrics import ConfusionMatrixDisplay
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
|
||||||
|
# import some data to play with
|
||||||
|
iris = datasets.load_iris()
|
||||||
|
X = iris.data
|
||||||
|
y = iris.target
|
||||||
|
class_names = iris.target_names
|
||||||
|
|
||||||
|
# Split the data into a training set and a test set
|
||||||
|
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
|
||||||
|
|
||||||
|
# Run classifier, using a model that is too regularized (C too low) to see
|
||||||
|
# the impact on the results
|
||||||
|
classifier = svm.SVC(kernel="linear", C=0.01).fit(X_train, y_train)
|
||||||
|
|
||||||
|
np.set_printoptions(precision=2)
|
||||||
|
|
||||||
|
# Plot non-normalized confusion matrix
|
||||||
|
titles_options = [
|
||||||
|
("Confusion matrix, without normalization", None),
|
||||||
|
("Normalized confusion matrix", "true"),
|
||||||
|
]
|
||||||
|
for title, normalize in titles_options:
|
||||||
|
disp = ConfusionMatrixDisplay.from_estimator(
|
||||||
|
classifier,
|
||||||
|
X_test,
|
||||||
|
y_test,
|
||||||
|
display_labels=class_names,
|
||||||
|
cmap=plt.cm.Blues,
|
||||||
|
normalize=normalize,
|
||||||
|
)
|
||||||
|
disp.ax_.set_title(title)
|
||||||
|
|
||||||
|
print(title)
|
||||||
|
print(disp.confusion_matrix)
|
||||||
|
|
||||||
|
plt.show()
|
||||||
Reference in New Issue
Block a user