
Commit 8fd2641

docs: add LLM-as-a-Judge example (#804)
* Download and process the HaluEval data
* Show how to use a table-based predictor
* Add visualizations
* Add naive threshold (also added to quickstart)
* Update history
1 parent da1043d commit 8fd2641
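The changes in this commit revolve around a single pattern: calibrate a BinaryClassificationController so that a chosen risk (precision here) is controlled on unseen data with a given confidence, then compare the resulting threshold with a naively chosen one. Below is a minimal sketch of that pattern; the synthetic data and the scores_fn placeholder are illustrative and not taken from the diff, only the controller's arguments and attributes are.

import numpy as np

from mapie.risk_control import BinaryClassificationController

# Illustrative calibration data: one informative feature per sample.
rng = np.random.default_rng(0)
y_calib = rng.integers(0, 2, size=500)
X_calib = (y_calib + rng.normal(scale=0.4, size=500)).reshape(-1, 1)


def scores_fn(X):
    # Placeholder scorer returning [P(negative), P(positive)] per row;
    # in the new example this role is played by an LLM judge's scores.
    p = 1 / (1 + np.exp(-4 * (X[:, 0] - 0.5)))
    return np.column_stack([1 - p, p])


bcc = BinaryClassificationController(
    predict_function=scores_fn,  # any callable with a predict_proba-like output
    risk="precision",  # the risk to control
    target_level=0.9,  # precision target
    confidence_level=0.9,  # probability with which the guarantee must hold
    best_predict_param_choice="recall",  # secondary objective used to pick the "best" threshold
)
bcc.calibrate(X_calib, y_calib)

# Thresholds that come with a statistical guarantee, and the one maximizing recall among them.
print(len(bcc.valid_predict_params), bcc.best_predict_param)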

File tree: 3 files changed (+278, -8 lines)


HISTORY.rst

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,8 @@ History
 
 1.x.x (2025-xx-xx)
 ------------------
+* Add an example of risk control with LLM as a judge
+* Add comparison with naive threshold in risk control quick start example
 
 1.2.0 (2025-11-17)
 ------------------

examples/risk_control/1-quickstart/plot_risk_control_binary_classification.py

Lines changed: 44 additions & 8 deletions
@@ -7,6 +7,8 @@
 
 """
 
+# sphinx_gallery_thumbnail_number = 2
+
 import matplotlib.pyplot as plt
 import numpy as np
 from sklearn.datasets import make_circles
@@ -18,7 +20,7 @@
 from mapie.risk_control import BinaryClassificationController
 from mapie.utils import train_conformalize_test_split
 
-RANDOM_STATE = 1
+RANDOM_STATE = 42
 
 ##############################################################################
 # First, load the dataset and then split it into training, calibration
@@ -28,9 +30,9 @@
 (X_train, X_calib, X_test, y_train, y_calib, y_test) = train_conformalize_test_split(
     X,
     y,
-    train_size=0.8,
+    train_size=0.7,
     conformalize_size=0.1,
-    test_size=0.1,
+    test_size=0.2,
     random_state=RANDOM_STATE,
 )
 
@@ -112,7 +114,7 @@
     f"{len(bcc.valid_predict_params)} thresholds found that guarantee a precision of "
     f"at least {target_precision} with a confidence of {confidence_level}.\n"
     "Among those, the one that maximizes the secondary objective (recall here) is: "
-    f"{bcc.best_predict_param:.3f}."
+    f"{bcc.best_predict_param:.2f}."
 )
 
 
@@ -128,6 +130,10 @@
     y_pred = (proba_positive_class >= threshold).astype(int)
     precisions[i] = precision_score(y_calib, y_pred)
 
+naive_threshold_index = np.argmin(
+    np.where(precisions >= target_precision, precisions - target_precision, np.inf)
+)
+
 valid_thresholds_indices = np.array(
     [t in bcc.valid_predict_params for t in tested_thresholds]
 )
@@ -155,6 +161,15 @@
     edgecolors="k",
     s=300,
 )
+plt.scatter(
+    tested_thresholds[naive_threshold_index],
+    precisions[naive_threshold_index],
+    c="tab:red",
+    label="Naive threshold",
+    marker="*",
+    edgecolors="k",
+    s=300,
+)
 plt.axhline(target_precision, color="tab:gray", linestyle="--")
 plt.text(
     0.7,
@@ -168,9 +183,28 @@
 plt.legend()
 plt.show()
 
+proba_positive_class_test = clf.predict_proba(X_test)[:, 1]
+y_pred_naive = (
+    proba_positive_class_test >= tested_thresholds[naive_threshold_index]
+).astype(int)
+print(
+    "With the naive threshold, the precision is:\n "
+    f"- {precisions[naive_threshold_index]:.3f} on the calibration set\n "
+    f"- {precision_score(y_test, y_pred_naive):.3f} on the test set."
+)
+
+print(
+    "\n\n With risk control, the precision is:\n"
+    f"- {precisions[best_threshold_index]:.3f} on the calibration set\n "
+    f"- {precision_score(y_test, bcc.predict(X_test)):.3f} on the test set."
+)
+
 ##############################################################################
 # Contrary to the naive way of computing a threshold to satisfy a precision target on
 # calibration data, risk control provides statistical guarantees on unseen data.
+# In this example, the naive threshold results in a precision on the test set that is
+# lower than the target precision while risk control takes a margin to guarantee
+# the target precision on unseen data with high probability.
 # In the plot above, we can see that not all thresholds corresponding to a precision
 # higher than the target are valid. This is due to the uncertainty inherent to the
 # finite size of the calibration set, which risk control takes into account.
@@ -179,12 +213,14 @@
 # small number of observations used to compute the precision, following the Learn Then
 # Test procedure. In the most extreme case, no observation is available, which causes
 # the precision value to be ill-defined and set to 0.
-
+#
 # Besides computing a set of valid thresholds,
 # :class:`~mapie.risk_control.BinaryClassificationController` also outputs the "best"
 # one, which is the valid threshold that maximizes a secondary objective
 # (recall here).
-#
+
+
+##############################################################################
 # After obtaining the best threshold, we can use the ``predict`` function of
 # :class:`~mapie.risk_control.BinaryClassificationController` for future predictions,
 # or use scikit-learn's ``FixedThresholdClassifier`` as a wrapper to benefit
@@ -206,15 +242,15 @@
     X_test[y_test == 0, 1],
     edgecolors="k",
     c="tab:blue",
-    alpha=0.5,
+    alpha=0.3,
     label='"negative" class',
 )
 plt.scatter(
     X_test[y_test == 1, 0],
     X_test[y_test == 1, 1],
     edgecolors="k",
     c="tab:red",
-    alpha=0.5,
+    alpha=0.3,
     label='"positive" class',
 )
 plt.title("Decision Boundary of FixedThresholdClassifier")
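The last hunk above only adjusts the transparency of the FixedThresholdClassifier decision-boundary plot; the wrapper itself is referenced in the comments but its construction sits outside this diff. As a minimal sketch, not part of the diff, the threshold found by risk control can be plugged into scikit-learn's FixedThresholdClassifier (available in scikit-learn 1.5+) roughly as follows; clf, bcc, X_train and y_train reuse the quickstart example's names.

from sklearn.model_selection import FixedThresholdClassifier

fixed_clf = FixedThresholdClassifier(
    estimator=clf,  # the base probabilistic classifier from the example
    threshold=bcc.best_predict_param,  # threshold selected by risk control
)
fixed_clf.fit(X_train, y_train)  # fits the wrapped estimator, then applies the fixed threshold
y_pred_test = fixed_clf.predict(X_test)  # thresholds the positive-class probability at best_predict_param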
New example file (Risk Control for LLM as a Judge)

Lines changed: 232 additions & 0 deletions
@@ -0,0 +1,232 @@
"""
Risk Control for LLM as a Judge
===============================

This example demonstrates how to use risk control methods for Large Language Models (LLMs) acting as judges.
We simulate a scenario where an LLM evaluates answers, and we want to control the risk of hallucination detection.
"""

# sphinx_gallery_thumbnail_number = 2

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.metrics import precision_score
from sklearn.model_selection import train_test_split

from mapie.risk_control import BinaryClassificationController

np.random.seed(0)

##############################################################################
# First, we load HaluEval Question-Answering Data, an open-source dataset for evaluating hallucination in LLMs.
# Then, we preprocess the data to create a suitable format for our analysis.
url = "https://raw.githubusercontent.com/RUCAIBox/HaluEval/main/data/qa_data.json"
df = pd.read_json(url, lines=True)
print("Sample of the original dataset:\n\n", df.iloc[0])

# Melt the dataframe to combine right_answer and hallucinated_answer into a single column
df = df.melt(
    id_vars=["knowledge", "question"],
    value_vars=["right_answer", "hallucinated_answer"],
    var_name="answer_type",
    value_name="answer",
    ignore_index=False,  # Keep the original index to allow sorting back to pairs
)

# Sort by index to keep the pairs together (right_answer and hallucinated_answer for
# the same question)
df = df.sort_index()

# Create the 'hallucinated' flag based on the original column name and drop the helper
# column 'answer_type'
df["hallucinated"] = df["answer_type"] == "hallucinated_answer"
df = df.drop(columns=["answer_type"])
df = df.reset_index(drop=True)

# Create judge input prompts
df["judge_input"] = df.apply(
    lambda row: f"""
You are a judge evaluating whether an answer to a question is faithful to the
provided knowledge snippet.

Knowledge: {row["knowledge"]}
Question: {row["question"]}
Answer: {row["answer"]}

Does the answer contain information that is NOT supported by the knowledge?

Provide a score between 0.0 and 1.0 indicating the probability that the answer is a
hallucination.
""",
    axis=1,
)

print("Sample of the processed dataset:\n\n", df.iloc[0])


##############################################################################
69+
# For demonstration purposes, we simulate the LLM judge's behavior using a simple table-based predictor.
70+
# In practice, you would replace this with actual LLM API calls to get judge scores or read from a file
71+
# of judge scores obtained from a complex LangChain pipeline for instance.
72+
73+
74+
class TableBasePredictor:
75+
def __init__(self, df):
76+
df["judge_score"] = df["hallucinated"].apply(self.generate_biased_score)
77+
self.df = df[["judge_input", "judge_score"]]
78+
self.df = self.df.set_index("judge_input")
79+
80+
def predict_proba(self, X):
81+
score_positive = self.df.loc[X]["judge_score"].values
82+
score_negative = 1 - score_positive
83+
return np.vstack([score_negative, score_positive]).T
84+
85+
@staticmethod
86+
def generate_biased_score(is_hallucinated):
87+
"""Generate a biased score based on whether the answer is hallucinated."""
88+
if is_hallucinated:
89+
return np.random.beta(a=3, b=1)
90+
else:
91+
return np.random.beta(a=1, b=3)
92+
93+
94+
llm_judge = TableBasePredictor(df)
95+
96+
plt.figure()
97+
plt.hist(
98+
df[df["hallucinated"]]["judge_score"],
99+
bins=30,
100+
alpha=0.8,
101+
label="Hallucinated answer",
102+
density=True,
103+
)
104+
plt.hist(
105+
df[~df["hallucinated"]]["judge_score"],
106+
bins=30,
107+
alpha=0.8,
108+
label="Correct answer",
109+
density=True,
110+
)
111+
plt.xlabel("Judge Score (Probability of Hallucination)")
112+
plt.ylabel("Density")
113+
plt.title("Distribution of Judge Scores")
114+
plt.legend()
115+
plt.show()
116+
117+
##############################################################################
118+
# Next, we split the data into calibration and test sets. We then initialize a
119+
# :class:`~mapie.risk_control.BinaryClassificationController` using the LLM judge's
120+
# probability estimation function, a risk metric (here, "precision"), a target risk level,
121+
# and a confidence level. We use the calibration data to compute statistically guaranteed thresholds.
122+
123+
X = df["judge_input"].to_numpy()
124+
y = df["hallucinated"].astype(int)
125+
126+
X_calib, X_test, y_calib, y_test = train_test_split(X, y, test_size=0.8, random_state=0)
127+
target_precision = 0.9
128+
confidence_level = 0.9
129+
130+
bcc = BinaryClassificationController(
131+
predict_function=llm_judge.predict_proba,
132+
risk="precision",
133+
target_level=target_precision,
134+
confidence_level=confidence_level,
135+
best_predict_param_choice="recall",
136+
)
137+
bcc.calibrate(X_calib, y_calib)
138+
139+
print(f"The best threshold is: {bcc.best_predict_param}")
140+
141+
y_calib_pred_controlled = bcc.predict(X_calib)
142+
precision_calib = precision_score(y_calib, y_calib_pred_controlled)
143+
144+
y_test_pred_controlled = bcc.predict(X_test)
145+
precision_test = precision_score(y_test, y_test_pred_controlled)
146+
147+
print(
148+
"With risk control, the precision is:\n"
149+
f"- {precision_calib:.3f} on the calibration set \n"
150+
f"- {precision_test:.3f} on the test set."
151+
)
152+
153+
##############################################################################
154+
# Finally, let us visualize the precision achieved on the calibration set for
155+
# the tested thresholds, highlighting the valid thresholds and the best one
156+
# (which maximizes recall).
157+
158+
proba_positive_class = llm_judge.predict_proba(X_calib)[:, 1]
159+
160+
tested_thresholds = bcc._predict_params
161+
precisions = np.full(len(tested_thresholds), np.inf)
162+
for i, threshold in enumerate(tested_thresholds):
163+
y_pred = (proba_positive_class >= threshold).astype(int)
164+
precisions[i] = precision_score(y_calib, y_pred)
165+
166+
naive_threshold_index = np.argmin(
167+
np.where(precisions >= target_precision, precisions - target_precision, np.inf)
168+
)
169+
naive_threshold = tested_thresholds[naive_threshold_index]
170+
171+
valid_thresholds_indices = np.array(
172+
[t in bcc.valid_predict_params for t in tested_thresholds]
173+
)
174+
best_threshold_index = np.where(tested_thresholds == bcc.best_predict_param)[0][0]
175+
176+
plt.figure()
177+
plt.scatter(
178+
tested_thresholds[valid_thresholds_indices],
179+
precisions[valid_thresholds_indices],
180+
c="tab:green",
181+
label="Valid thresholds",
182+
)
183+
plt.scatter(
184+
tested_thresholds[~valid_thresholds_indices],
185+
precisions[~valid_thresholds_indices],
186+
c="tab:red",
187+
label="Invalid thresholds",
188+
)
189+
plt.scatter(
190+
tested_thresholds[best_threshold_index],
191+
precisions[best_threshold_index],
192+
c="tab:green",
193+
label="Best threshold",
194+
marker="*",
195+
edgecolors="k",
196+
s=300,
197+
)
198+
plt.scatter(
199+
tested_thresholds[naive_threshold_index],
200+
precisions[naive_threshold_index],
201+
c="tab:red",
202+
label="Naive threshold",
203+
marker="*",
204+
edgecolors="k",
205+
s=300,
206+
)
207+
plt.axhline(target_precision, color="tab:gray", linestyle="--")
208+
plt.text(
209+
0.7,
210+
target_precision + 0.02,
211+
"Target precision",
212+
color="tab:gray",
213+
fontstyle="italic",
214+
)
215+
plt.xlabel("Threshold")
216+
plt.ylabel("Precision")
217+
plt.legend()
218+
plt.show()
219+
220+
proba_positive_class_test = llm_judge.predict_proba(X_test)[:, 1]
221+
y_pred_naive = (proba_positive_class_test >= naive_threshold).astype(int)
222+
223+
print(
224+
"With the naive threshold, the precision is:\n"
225+
f"- {precisions[naive_threshold_index]:.3f} on the calibration set\n"
226+
f"- {precision_score(y_test, y_pred_naive):.3f} on the test set."
227+
)
228+
##############################################################################
# While the naive threshold achieves the target precision on the calibration set,
# it fails to do so on the test set. This highlights the importance of using
# risk control methods to ensure that performance guarantees hold on unseen data.
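The example stubs the judge with TableBasePredictor, and its comments note that real judge scores would come from LLM calls. Below is a minimal sketch, not part of the commit, of what a judge-backed predict_function could look like; call_llm_judge is a hypothetical helper, not an API from MAPIE or HaluEval.

import numpy as np


def call_llm_judge(prompt: str) -> float:
    # Hypothetical helper: send the prompt to whichever LLM you use and parse
    # a hallucination probability in [0, 1] from its reply.
    raise NotImplementedError


def judge_predict_proba(X):
    # Same contract as TableBasePredictor.predict_proba: one row of
    # [P(not hallucinated), P(hallucinated)] per prompt in X.
    score_positive = np.array([call_llm_judge(prompt) for prompt in X])
    return np.vstack([1 - score_positive, score_positive]).T


# Used exactly like the table-based stand-in:
# bcc = BinaryClassificationController(predict_function=judge_predict_proba, ...)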

0 commit comments
