Commit 051ed55

Add conformal quantile regression
1 parent 905a6d6 commit 051ed55

File tree

3 files changed (+338, -21 lines)

autoemulate/emulators/conformal.py

Lines changed: 225 additions & 18 deletions
@@ -1,7 +1,8 @@
 import math
+from typing import Literal
 
 import torch
-from torch import Tensor, nn
+from torch import nn
 from torch.optim.lr_scheduler import LRScheduler
 
 from autoemulate.core.device import TorchDeviceMixin
@@ -10,6 +11,69 @@
 from autoemulate.emulators.nn.mlp import MLP
 
 
+class QuantileLoss(nn.Module):
+    """Quantile loss for quantile regression.
+
+    This loss function asymmetrically penalizes over- and under-predictions, enabling
+    the model to learn specific quantiles of the conditional distribution.
+    """
+
+    def __init__(self, quantile: float):
+        """Initialize quantile loss.
+
+        Parameters
+        ----------
+        quantile: float
+            Target quantile level in (0, 1). For example, 0.1 for 10th percentile, 0.5
+            for median, 0.9 for 90th percentile.
+        """
+        super().__init__()
+        if not 0 < quantile < 1:
+            msg = f"Quantile must be in (0, 1), got {quantile}"
+            raise ValueError(msg)
+        self.quantile = quantile
+
+    def forward(self, y_pred: TensorLike, y_true: TensorLike) -> TensorLike:
+        """Compute quantile loss.
+
+        Parameters
+        ----------
+        y_pred: TensorLike
+            Predicted values.
+        y_true: TensorLike
+            True target values.
+
+        Returns
+        -------
+        TensorLike
+            Scalar loss value.
+        """
+        errors = y_true - y_pred
+        return torch.max(self.quantile * errors, (self.quantile - 1) * errors).mean()
+
+
+class QuantileMLP(MLP):
+    """MLP with quantile loss for quantile regression."""
+
+    def __init__(self, quantile: float, **kwargs):
+        """Initialize quantile MLP.
+
+        Parameters
+        ----------
+        quantile: float
+            Target quantile level in (0, 1).
+        **kwargs
+            Keyword arguments passed to MLP parent class.
+        """
+        super().__init__(**kwargs)
+        self.quantile = quantile
+        self.quantile_loss = QuantileLoss(quantile)
+
+    def loss_func(self, y_pred, y_true):
+        """Quantile loss function."""
+        return self.quantile_loss(y_pred, y_true)
+
+
 class Conformal(Emulator):
     """Conformal Uncertainty Quantification (UQ) wrapper for emulators.
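A quick way to sanity-check the added QuantileLoss: minimizing the pinball loss over a constant prediction recovers the empirical quantile of the targets. A minimal standalone sketch in plain torch (the `pinball` helper below just inlines `QuantileLoss.forward` from the diff above):

import torch

def pinball(y_pred, y_true, q):
    # Inlined from QuantileLoss.forward above.
    errors = y_true - y_pred
    return torch.max(q * errors, (q - 1) * errors).mean()

torch.manual_seed(0)
y = torch.randn(10_000)

# Fit a single constant by gradient descent; it should drift towards the
# empirical 0.9 quantile of y (about 1.28 for a standard normal).
theta = torch.zeros(1, requires_grad=True)
opt = torch.optim.SGD([theta], lr=0.1)
for _ in range(2_000):
    opt.zero_grad()
    pinball(theta, y, 0.9).backward()
    opt.step()

print(theta.item(), torch.quantile(y, 0.9).item())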
@@ -26,6 +90,8 @@ def __init__(
         device: DeviceLike | None = None,
         calibration_ratio: float = 0.2,
         n_samples: int = 1000,
+        method: Literal["split", "quantile"] = "split",
+        quantile_emulator_kwargs: dict | None = None,
     ):
         """Initialize a conformal emulator.
@@ -42,8 +108,16 @@ def __init__(
             Fraction of the training data to reserve for calibration if explicit
             validation data is not provided. Must lie in (0, 1). Defaults to 0.2.
         n_samples: int
-            Number of samples used for sampling-based predictions or
-            internal procedures. Defaults to 1000.
+            Number of samples used for sampling-based predictions or internal
+            procedures. Defaults to 1000.
+        method: Literal["split", "quantile"]
+            Conformalization method to use:
+            - "split": Standard split conformal with constant-width intervals
+            - "quantile": Conformalized Quantile Regression (CQR) with input-dependent
+              intervals. Defaults to "split".
+        quantile_emulator_kwargs: dict | None
+            Additional keyword arguments for the quantile emulators when
+            method="quantile". Defaults to None.
         """
         self.emulator = emulator
         self.supports_grad = emulator.supports_grad
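For orientation, this is how the two new arguments documented above would be passed. A hypothetical call, where `my_emulator`, `x`, and `y` are placeholders for a base emulator and training tensors, not names from this diff:

# Hypothetical usage of the signature documented above.
conformal = Conformal(
    emulator=my_emulator,
    alpha=0.9,                  # target 90% predictive coverage
    calibration_ratio=0.2,      # hold out 20% of training data for calibration
    method="quantile",          # CQR instead of standard split conformal
    quantile_emulator_kwargs={"epochs": 200, "lr": 1e-3},
)
conformal.fit(x, y)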
@@ -53,9 +127,14 @@ def __init__(
         if not 0 < calibration_ratio < 1:
             msg = "Calibration ratio must lie strictly between 0 and 1."
             raise ValueError(msg)
+        if method not in {"split", "quantile"}:
+            msg = f"Method must be 'split' or 'quantile', got '{method}'."
+            raise ValueError(msg)
         self.alpha = alpha  # desired predictive coverage (e.g., 0.95)
         self.calibration_ratio = calibration_ratio
         self.n_samples = n_samples
+        self.method = method
+        self.quantile_emulator_kwargs = quantile_emulator_kwargs or {}
         TorchDeviceMixin.__init__(self, device=device)
         self.supports_grad = emulator.supports_grad
 
@@ -98,36 +177,129 @@ def _fit(
         else:
             x_cal, y_true_cal = validation_data
 
+        # Fit the base emulator
         self.emulator.fit(x_train, y_train, validation_data=None)
 
-        with torch.no_grad():
-            n_cal = x_cal.shape[0]
-            # Check calibration data is non-empty
-            if n_cal == 0:
-                msg = "Calibration set must contain at least one sample."
-                raise ValueError(msg)
+        n_cal = x_cal.shape[0]
+        # Check calibration data is non-empty
+        if n_cal == 0:
+            msg = "Calibration set must contain at least one sample."
+            raise ValueError(msg)
 
+        with torch.no_grad():
             # Predict and calculate residuals
             y_pred_cal = self.output_to_tensor(self.emulator.predict(x_cal))
+
+            if self.method == "split":
+                # Standard split conformal: compute global quantile of residuals
                 residuals = torch.abs(y_true_cal - y_pred_cal)
 
-            # Apply finite-sample correction to quantile level to ensure valid coverage
+                # Apply finite-sample correction to quantile level
                 quantile_level = min(1.0, math.ceil((n_cal + 1) * self.alpha) / n_cal)
 
-            # Calibrate over the batch dim with a separate quantile for each output
+                # Calibrate over the batch dim with a separate quantile per output
                 self.q = torch.quantile(residuals, quantile_level, dim=0)
 
+            elif self.method == "quantile":
+                # Conformalized Quantile Regression: train quantile regressors
+                self._fit_quantile_regressors(x_train, y_train, x_cal, y_true_cal)
+
         self.is_fitted_ = True
 
-    def _predict(self, x: Tensor, with_grad: bool) -> DistributionLike:
-        pred = self.emulator.predict(x, with_grad)
-        mean = self.output_to_tensor(pred)
-        q = self.q.to(mean.device)
-        return torch.distributions.Independent(
-            torch.distributions.Uniform(mean - q, mean + q),
-            reinterpreted_batch_ndims=mean.ndim - 1,
+    def _fit_quantile_regressors(
+        self,
+        x_train: TensorLike,
+        y_train: TensorLike,
+        x_cal: TensorLike,
+        y_true_cal: TensorLike,
+    ) -> None:
+        """Fit quantile regressors for CQR method.
+
+        Trains two quantile regressors to predict lower and upper quantiles,
+        then calibrates the width using the calibration set.
+        """
+        # Calculate quantile levels
+        lower_q = (1 - self.alpha) / 2
+        upper_q = 1 - lower_q
+
+        # Create quantile regression emulators
+        mlp_kwargs = {
+            "epochs": 100,
+            "batch_size": 16,
+            "lr": 1e-2,
+            **self.quantile_emulator_kwargs,
+        }
+
+        # Lower quantile emulator
+        self.lower_quantile_emulator = QuantileMLP(
+            lower_q, x=x_train, y=y_train, device=self.device, **mlp_kwargs
         )
 
+        # Upper quantile emulator
+        self.upper_quantile_emulator = QuantileMLP(
+            upper_q, x=x_train, y=y_train, device=self.device, **mlp_kwargs
+        )
+
+        # Fit the quantile emulators
+        self.lower_quantile_emulator.fit(x_train, y_train, validation_data=None)
+        self.upper_quantile_emulator.fit(x_train, y_train, validation_data=None)
+
+        # Predict quantiles on calibration set
+        with torch.no_grad():
+            lower_pred_cal = self.output_to_tensor(
+                self.lower_quantile_emulator.predict(x_cal)
+            )
+            upper_pred_cal = self.output_to_tensor(
+                self.upper_quantile_emulator.predict(x_cal)
+            )
+
+        # Calculate conformalization scores (non-conformity scores)
+        # For CQR, the score is max(lower - y, y - upper)
+        scores = torch.maximum(
+            lower_pred_cal - y_true_cal, y_true_cal - upper_pred_cal
+        )
+
+        # Apply finite-sample correction
+        n_cal = x_cal.shape[0]
+        quantile_level = min(1.0, math.ceil((n_cal + 1) * self.alpha) / n_cal)
+
+        # Compute the correction term per output dimension
+        self.q_cqr = torch.quantile(scores, quantile_level, dim=0)
+
+    def _predict(self, x: TensorLike, with_grad: bool) -> DistributionLike:
+        if self.method == "split":
+            # Standard split conformal: constant-width intervals
+            pred = self.emulator.predict(x, with_grad)
+            mean = self.output_to_tensor(pred)
+            q = self.q.to(mean.device)
+            return torch.distributions.Independent(
+                torch.distributions.Uniform(mean - q, mean + q),
+                reinterpreted_batch_ndims=mean.ndim - 1,
+            )
+
+        if self.method == "quantile":
+            # CQR: input-dependent intervals
+            lower_pred = self.output_to_tensor(
+                self.lower_quantile_emulator.predict(x, with_grad)
+            )
+            upper_pred = self.output_to_tensor(
+                self.upper_quantile_emulator.predict(x, with_grad)
+            )
+            q_cqr = self.q_cqr.to(lower_pred.device)
+
+            # Apply calibration correction
+            lower_bound = lower_pred - q_cqr
+            upper_bound = upper_pred + q_cqr
+
+            # Return uniform distribution over the calibrated interval
+            return torch.distributions.Independent(
+                torch.distributions.Uniform(lower_bound, upper_bound),
+                reinterpreted_batch_ndims=lower_bound.ndim - 1,
+            )
+
+        msg = f"Unknown method: {self.method}"
+        raise ValueError(msg)
+
 
 class ConformalMLP(Conformal, PyTorchBackend):
     """Conformal UQ with an MLP.
@@ -146,6 +318,7 @@ def __init__(
         device: DeviceLike | None = None,
         alpha: float = 0.95,
         calibration_ratio: float = 0.2,
+        method: Literal["split", "quantile"] = "split",
         activation_cls: type[nn.Module] = nn.ReLU,
         loss_fn_cls: type[nn.Module] = nn.MSELoss,
         epochs: int = 100,
@@ -160,6 +333,7 @@ def __init__(
         random_seed: int | None = None,
         scheduler_cls: type[LRScheduler] | None = None,
         scheduler_params: dict | None = None,
+        quantile_emulator_kwargs: dict | None = None,
     ):
         """
         Initialize an ensemble of MLPs.
@@ -181,6 +355,11 @@ def __init__(
         calibration_ratio: float
             Fraction of training samples to hold out for calibration when an explicit
             validation set is not provided.
+        method: Literal["split", "quantile"]
+            Conformalization method:
+            - "split": Standard split conformal (constant-width intervals)
+            - "quantile": Conformalized Quantile Regression (input-dependent intervals)
+            Defaults to "split".
         activation_cls: type[nn.Module]
             Activation function to use in the hidden layers. Defaults to `nn.ReLU`.
         loss_fn_cls: type[nn.Module]
@@ -218,6 +397,9 @@ def __init__(
             None.
         scheduler_params: dict | None
             Additional keyword arguments related to the scheduler.
+        quantile_emulator_kwargs: dict | None
+            Additional keyword arguments for the quantile emulators when
+            method="quantile". Defaults to None.
         """
         nn.Module.__init__(self)
 
@@ -242,12 +424,37 @@ def __init__(
             scheduler_cls=scheduler_cls,
             scheduler_params=scheduler_params,
         )
+
+        quantile_defaults = {
+            "standardize_x": standardize_x,
+            "standardize_y": standardize_y,
+            "activation_cls": activation_cls,
+            "loss_fn_cls": loss_fn_cls,
+            "epochs": epochs,
+            "batch_size": batch_size,
+            "layer_dims": layer_dims,
+            "weight_init": weight_init,
+            "scale": scale,
+            "bias_init": bias_init,
+            "dropout_prob": dropout_prob,
+            "lr": lr,
+            "params_size": params_size,
+            "random_seed": random_seed,
+            "scheduler_cls": scheduler_cls,
+            "scheduler_params": scheduler_params,
+        }
+        merged_quantile_kwargs = {
+            **quantile_defaults,
+            **(quantile_emulator_kwargs or {}),
+        }
         Conformal.__init__(
             self,
             emulator=emulator,
             alpha=alpha,
             device=device,
             calibration_ratio=calibration_ratio,
+            method=method,
+            quantile_emulator_kwargs=merged_quantile_kwargs,
         )
 
     @staticmethod
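Finally, a sketch of exercising the new method end to end, with an empirical coverage check on heteroscedastic data, the case where CQR's input-dependent intervals should pay off. Only the keyword arguments shown in this diff are confirmed by it; the positional `x, y` constructor arguments and reading the interval bounds via `dist.base_dist` (an Independent-wrapped Uniform, per `_predict` above) are assumptions about the surrounding API:

import torch
from autoemulate.emulators.conformal import ConformalMLP

# Toy heteroscedastic data: noise grows with x.
torch.manual_seed(0)
x = torch.rand(1_000, 1)
y = torch.sin(6 * x) + x * torch.randn(1_000, 1)

# Assumed constructor shape (x, y first), matching the QuantileMLP calls above.
emulator = ConformalMLP(x, y, alpha=0.9, method="quantile", epochs=200)
emulator.fit(x, y)

# Empirical coverage on fresh data should land near alpha = 0.9.
x_test = torch.rand(500, 1)
y_test = torch.sin(6 * x_test) + x_test * torch.randn(500, 1)
dist = emulator.predict(x_test)
inside = (y_test >= dist.base_dist.low) & (y_test <= dist.base_dist.high)
print(inside.float().mean())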