Confusion_Matrix.py
This commit is contained in:
47
tools/sklearn/Confusion_Matrix.py
Normal file
47
tools/sklearn/Confusion_Matrix.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
from sklearn import datasets, svm
|
||||
from sklearn.metrics import ConfusionMatrixDisplay
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
# import some data to play with
|
||||
iris = datasets.load_iris()
|
||||
X = iris.data
|
||||
y = iris.target
|
||||
class_names = iris.target_names
|
||||
|
||||
# Split the data into a training set and a test set
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
|
||||
|
||||
# Run classifier, using a model that is too regularized (C too low) to see
|
||||
# the impact on the results
|
||||
classifier = svm.SVC(kernel="linear", C=0.01).fit(X_train, y_train)
|
||||
|
||||
np.set_printoptions(precision=2)
|
||||
|
||||
# Plot non-normalized confusion matrix
|
||||
titles_options = [
|
||||
("Confusion matrix, without normalization", None),
|
||||
("Normalized confusion matrix", "true"),
|
||||
]
|
||||
for title, normalize in titles_options:
|
||||
disp = ConfusionMatrixDisplay.from_estimator(
|
||||
classifier,
|
||||
X_test,
|
||||
y_test,
|
||||
display_labels=class_names,
|
||||
cmap=plt.cm.Blues,
|
||||
normalize=normalize,
|
||||
)
|
||||
disp.ax_.set_title(title)
|
||||
|
||||
print(title)
|
||||
print(disp.confusion_matrix)
|
||||
|
||||
plt.show()
|
||||
Reference in New Issue
Block a user