"""
=======================================
Detection error tradeoff (DET) curve
=======================================

In this example, we compare receiver operating characteristic (ROC) and
detection error tradeoff (DET) curves for different classification algorithms
for the same classification task.

DET curves are commonly plotted in normal deviate scale.
To achieve this we transform the error rates as returned by the
``detection_error_tradeoff_curve`` function and the axis scale using
``scipy.stats.norm``.

The point of this example is to demonstrate two properties of DET curves,
namely:

1. It might be easier to visually assess the overall performance of different
   classification algorithms using DET curves over ROC curves.
   Due to the linear scale used for plotting ROC curves, different classifiers
   usually only differ in the top left corner of the graph and appear similar
   for a large part of the plot. On the other hand, DET curves represent
   straight lines in normal deviate scale. As such, they tend to be
   distinguishable as a whole and the area of interest spans a large part of
   the plot.
2. DET curves give the user direct feedback of the detection error tradeoff to
   aid in operating point analysis.
   The user can deduce directly from the DET-curve plot at which rate the
   false-negative error rate will improve when willing to accept an increase
   in the false-positive error rate (or vice-versa).

The plots in this example compare ROC curves on the left side to corresponding
DET curves on the right.
There is no particular reason why these classifiers have been chosen for the
example plot over other classifiers available in scikit-learn.

.. note::

    - See :func:`sklearn.metrics.roc_curve` for further information about ROC
      curves.

    - See :func:`sklearn.metrics.detection_error_tradeoff_curve` for further
      information about DET curves.

    - This example is loosely based on
      :ref:`sphx_glr_auto_examples_classification_plot_classifier_comparison.py`
      .

"""
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import make_classification
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
# NOTE(review): ``detection_error_tradeoff_curve`` was renamed ``det_curve``
# in later scikit-learn releases -- confirm the scikit-learn version targeted
# by this example.
from sklearn.metrics import detection_error_tradeoff_curve
from sklearn.metrics import roc_curve

from scipy.stats import norm
from matplotlib.ticker import FuncFormatter

N_SAMPLES = 1000

names = [
    "Linear SVM",
    "Random Forest",
]

classifiers = [
    SVC(kernel="linear", C=0.025),
    RandomForestClassifier(max_depth=5, n_estimators=10, max_features=1),
]

# Build a synthetic two-feature binary classification task; fixed random_state
# keeps the example reproducible.
X, y = make_classification(
    n_samples=N_SAMPLES, n_features=2, n_redundant=0, n_informative=2,
    random_state=1, n_clusters_per_class=1)

# preprocess dataset, split into training and test part
X = StandardScaler().fit_transform(X)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=.4, random_state=0)

# prepare plots: ROC curves on the left axes, DET curves on the right
fig, [ax_roc, ax_det] = plt.subplots(1, 2, figsize=(10, 5))

# first prepare the ROC curve (linear scale, ticks formatted as percentages)
ax_roc.set_title('Receiver Operating Characteristic (ROC) curves')
ax_roc.set_xlabel('False Positive Rate')
ax_roc.set_ylabel('True Positive Rate')
ax_roc.set_xlim(0, 1)
ax_roc.set_ylim(0, 1)
ax_roc.grid(linestyle='--')
ax_roc.yaxis.set_major_formatter(
    FuncFormatter(lambda y, _: '{:.0%}'.format(y)))
ax_roc.xaxis.set_major_formatter(
    FuncFormatter(lambda y, _: '{:.0%}'.format(y)))

# second prepare the DET curve; the data is plotted in normal deviate scale,
# so the axis limits are in deviate units (norm.ppf(0.001) is about -3.09)
ax_det.set_title('Detection Error Tradeoff (DET) curves')
ax_det.set_xlabel('False Positive Rate')
ax_det.set_ylabel('False Negative Rate')
ax_det.set_xlim(-3, 3)
ax_det.set_ylim(-3, 3)
ax_det.grid(linestyle='--')

# customized ticks for DET curve plot to represent normal deviate scale:
# tick positions are the deviates norm.ppf(p), labels remain probabilities
ticks = [0.001, 0.01, 0.05, 0.20, 0.5, 0.80, 0.95, 0.99, 0.999]
tick_locs = norm.ppf(ticks)
tick_lbls = [
    '{:.0%}'.format(s) if (100*s).is_integer() else '{:.1%}'.format(s)
    for s in ticks
]
plt.sca(ax_det)
plt.xticks(tick_locs, tick_lbls)
plt.yticks(tick_locs, tick_lbls)

# iterate over classifiers: fit, score the test set, and draw both curves
for name, clf in zip(names, classifiers):
    clf.fit(X_train, y_train)

    # prefer decision_function scores; fall back to the positive-class
    # probability for estimators that only expose predict_proba
    if hasattr(clf, "decision_function"):
        y_score = clf.decision_function(X_test)
    else:
        y_score = clf.predict_proba(X_test)[:, 1]

    roc_fpr, roc_tpr, _ = roc_curve(y_test, y_score)
    det_fpr, det_fnr, _ = detection_error_tradeoff_curve(y_test, y_score)

    ax_roc.plot(roc_fpr, roc_tpr)

    # transform errors into normal deviate scale
    ax_det.plot(
        norm.ppf(det_fpr),
        norm.ppf(det_fnr)
    )

# add a single legend (curves were plotted in the same order as `names`)
plt.sca(ax_det)
plt.legend(names, loc="upper right")

# plot
plt.tight_layout()
plt.show()