 import math
+from typing import Literal
 
 import torch
-from torch import Tensor, nn
+from torch import nn
 from torch.optim.lr_scheduler import LRScheduler
 
 from autoemulate.core.device import TorchDeviceMixin
 from autoemulate.emulators.nn.mlp import MLP
 
 
+class QuantileLoss(nn.Module):
+    """Quantile loss for quantile regression.
+
+    This loss function asymmetrically penalizes over- and under-predictions, enabling
+    the model to learn specific quantiles of the conditional distribution.
+    """
+
+    def __init__(self, quantile: float):
+        """Initialize quantile loss.
+
+        Parameters
+        ----------
+        quantile: float
+            Target quantile level in (0, 1). For example, 0.1 for the 10th
+            percentile, 0.5 for the median, 0.9 for the 90th percentile.
+        """
+        super().__init__()
+        if not 0 < quantile < 1:
+            msg = f"Quantile must be in (0, 1), got {quantile}"
+            raise ValueError(msg)
+        self.quantile = quantile
+
+    def forward(self, y_pred: TensorLike, y_true: TensorLike) -> TensorLike:
+        """Compute quantile loss.
+
+        Parameters
+        ----------
+        y_pred: TensorLike
+            Predicted values.
+        y_true: TensorLike
+            True target values.
+
+        Returns
+        -------
+        TensorLike
+            Scalar loss value.
+        """
+        errors = y_true - y_pred
+        return torch.max(self.quantile * errors, (self.quantile - 1) * errors).mean()
+
+
+class QuantileMLP(MLP):
+    """MLP trained with quantile loss for quantile regression."""
+
+    def __init__(self, quantile: float, **kwargs):
+        """Initialize quantile MLP.
+
+        Parameters
+        ----------
+        quantile: float
+            Target quantile level in (0, 1).
+        **kwargs
+            Keyword arguments passed to the MLP parent class.
+        """
+        super().__init__(**kwargs)
+        self.quantile = quantile
+        self.quantile_loss = QuantileLoss(quantile)
+
+    def loss_func(self, y_pred, y_true):
+        """Compute the quantile loss between predictions and targets."""
+        return self.quantile_loss(y_pred, y_true)
+
+
 class Conformal(Emulator):
     """Conformal Uncertainty Quantification (UQ) wrapper for emulators.
 
@@ -26,6 +90,8 @@ def __init__(
         device: DeviceLike | None = None,
         calibration_ratio: float = 0.2,
         n_samples: int = 1000,
+        method: Literal["split", "quantile"] = "split",
+        quantile_emulator_kwargs: dict | None = None,
     ):
         """Initialize a conformal emulator.
 
@@ -42,8 +108,16 @@ def __init__(
             Fraction of the training data to reserve for calibration if explicit
             validation data is not provided. Must lie in (0, 1). Defaults to 0.2.
         n_samples: int
-            Number of samples used for sampling-based predictions or
-            internal procedures. Defaults to 1000.
+            Number of samples used for sampling-based predictions or internal
+            procedures. Defaults to 1000.
+        method: Literal["split", "quantile"]
+            Conformalization method to use:
+            - "split": standard split conformal with constant-width intervals.
+            - "quantile": Conformalized Quantile Regression (CQR) with
+              input-dependent intervals.
+            Defaults to "split".
+        quantile_emulator_kwargs: dict | None
+            Additional keyword arguments for the quantile emulators when
+            method="quantile". Defaults to None.
         """
         self.emulator = emulator
         self.supports_grad = emulator.supports_grad
@@ -53,9 +127,14 @@ def __init__(
         if not 0 < calibration_ratio < 1:
             msg = "Calibration ratio must lie strictly between 0 and 1."
             raise ValueError(msg)
+        if method not in {"split", "quantile"}:
+            msg = f"Method must be 'split' or 'quantile', got '{method}'."
+            raise ValueError(msg)
         self.alpha = alpha  # desired predictive coverage (e.g., 0.95)
         self.calibration_ratio = calibration_ratio
         self.n_samples = n_samples
+        self.method = method
+        self.quantile_emulator_kwargs = quantile_emulator_kwargs or {}
         TorchDeviceMixin.__init__(self, device=device)
         self.supports_grad = emulator.supports_grad
 
@@ -98,36 +177,129 @@ def _fit(
         else:
             x_cal, y_true_cal = validation_data
 
+        # Fit the base emulator
         self.emulator.fit(x_train, y_train, validation_data=None)
 
-        with torch.no_grad():
-            n_cal = x_cal.shape[0]
-            # Check calibration data is non-empty
-            if n_cal == 0:
-                msg = "Calibration set must contain at least one sample."
-                raise ValueError(msg)
+        n_cal = x_cal.shape[0]
+        # Check calibration data is non-empty
+        if n_cal == 0:
+            msg = "Calibration set must contain at least one sample."
+            raise ValueError(msg)
 
+        if self.method == "split":
+            # Standard split conformal: compute a global quantile of the
+            # absolute calibration residuals
+            with torch.no_grad():
                 # Predict and calculate residuals
                 y_pred_cal = self.output_to_tensor(self.emulator.predict(x_cal))
                 residuals = torch.abs(y_true_cal - y_pred_cal)
 
-            # Apply finite-sample correction to quantile level to ensure valid coverage
+                # Apply finite-sample correction to quantile level
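+                # ceil((n_cal + 1) * alpha) / n_cal is the standard split-conformal
+                # finite-sample adjustment, guaranteeing at least alpha coverage
+                # under exchangeability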
                 quantile_level = min(1.0, math.ceil((n_cal + 1) * self.alpha) / n_cal)
 
-            # Calibrate over the batch dim with a separate quantile for each output
+                # Calibrate over the batch dim with a separate quantile per output
                 self.q = torch.quantile(residuals, quantile_level, dim=0)
 
+        elif self.method == "quantile":
+            # Conformalized Quantile Regression: train quantile regressors.
+            # This runs outside torch.no_grad() so their training can backpropagate.
+            self._fit_quantile_regressors(x_train, y_train, x_cal, y_true_cal)
+
         self.is_fitted_ = True
 
-    def _predict(self, x: Tensor, with_grad: bool) -> DistributionLike:
-        pred = self.emulator.predict(x, with_grad)
-        mean = self.output_to_tensor(pred)
-        q = self.q.to(mean.device)
-        return torch.distributions.Independent(
-            torch.distributions.Uniform(mean - q, mean + q),
-            reinterpreted_batch_ndims=mean.ndim - 1,
+    def _fit_quantile_regressors(
+        self,
+        x_train: TensorLike,
+        y_train: TensorLike,
+        x_cal: TensorLike,
+        y_true_cal: TensorLike,
+    ) -> None:
+        """Fit quantile regressors for the CQR method.
+
+        Trains two quantile regressors to predict the lower and upper quantiles,
+        then calibrates the interval width using the calibration set.
+        """
+        # Calculate the lower and upper quantile levels from the target coverage
+        lower_q = (1 - self.alpha) / 2
+        upper_q = 1 - lower_q
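+        # e.g. the default alpha = 0.95 gives lower_q = 0.025 and upper_q = 0.975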
+
+        # Create the quantile regression emulators
+        mlp_kwargs = {
+            "epochs": 100,
+            "batch_size": 16,
+            "lr": 1e-2,
+            **self.quantile_emulator_kwargs,
+        }
+
+        # Lower quantile emulator
+        self.lower_quantile_emulator = QuantileMLP(
+            lower_q, x=x_train, y=y_train, device=self.device, **mlp_kwargs
         )
 
+        # Upper quantile emulator
+        self.upper_quantile_emulator = QuantileMLP(
+            upper_q, x=x_train, y=y_train, device=self.device, **mlp_kwargs
+        )
+
+        # Fit the quantile emulators
+        self.lower_quantile_emulator.fit(x_train, y_train, validation_data=None)
+        self.upper_quantile_emulator.fit(x_train, y_train, validation_data=None)
+
+        # Predict quantiles on the calibration set
+        with torch.no_grad():
+            lower_pred_cal = self.output_to_tensor(
+                self.lower_quantile_emulator.predict(x_cal)
+            )
+            upper_pred_cal = self.output_to_tensor(
+                self.upper_quantile_emulator.predict(x_cal)
+            )
+
+            # Calculate the conformalization (non-conformity) scores.
+            # For CQR, the score is max(lower - y, y - upper)
+            scores = torch.maximum(
+                lower_pred_cal - y_true_cal, y_true_cal - upper_pred_cal
+            )
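+            # Scores are positive when y_true falls outside the raw
+            # [lower, upper] band and negative when it falls inside, so the
+            # calibrated correction below can either widen or shrink the band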
+
+            # Apply finite-sample correction
+            n_cal = x_cal.shape[0]
+            quantile_level = min(1.0, math.ceil((n_cal + 1) * self.alpha) / n_cal)
+
+            # Compute the correction term per output dimension
+            self.q_cqr = torch.quantile(scores, quantile_level, dim=0)
+
+    def _predict(self, x: TensorLike, with_grad: bool) -> DistributionLike:
+        if self.method == "split":
+            # Standard split conformal: constant-width intervals
+            pred = self.emulator.predict(x, with_grad)
+            mean = self.output_to_tensor(pred)
+            q = self.q.to(mean.device)
+            return torch.distributions.Independent(
+                torch.distributions.Uniform(mean - q, mean + q),
+                reinterpreted_batch_ndims=mean.ndim - 1,
+            )
+
+        if self.method == "quantile":
+            # CQR: input-dependent intervals
+            lower_pred = self.output_to_tensor(
+                self.lower_quantile_emulator.predict(x, with_grad)
+            )
+            upper_pred = self.output_to_tensor(
+                self.upper_quantile_emulator.predict(x, with_grad)
+            )
+            q_cqr = self.q_cqr.to(lower_pred.device)
+
+            # Apply the calibration correction
+            lower_bound = lower_pred - q_cqr
+            upper_bound = upper_pred + q_cqr
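+            # Note: Uniform requires lower < upper elementwise, so this may
+            # raise if the learned quantiles cross after calibration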
+
+            # Return a uniform distribution over the calibrated interval
+            return torch.distributions.Independent(
+                torch.distributions.Uniform(lower_bound, upper_bound),
+                reinterpreted_batch_ndims=lower_bound.ndim - 1,
+            )
+
+        msg = f"Unknown method: {self.method}"
+        raise ValueError(msg)
+
 
 class ConformalMLP(Conformal, PyTorchBackend):
     """Conformal UQ with an MLP.
@@ -146,6 +318,7 @@ def __init__(
         device: DeviceLike | None = None,
         alpha: float = 0.95,
         calibration_ratio: float = 0.2,
+        method: Literal["split", "quantile"] = "split",
         activation_cls: type[nn.Module] = nn.ReLU,
         loss_fn_cls: type[nn.Module] = nn.MSELoss,
         epochs: int = 100,
@@ -160,6 +333,7 @@ def __init__(
         random_seed: int | None = None,
         scheduler_cls: type[LRScheduler] | None = None,
         scheduler_params: dict | None = None,
+        quantile_emulator_kwargs: dict | None = None,
     ):
164338 """
165339 Initialize an ensemble of MLPs.
@@ -181,6 +355,11 @@ def __init__(
         calibration_ratio: float
             Fraction of training samples to hold out for calibration when an explicit
             validation set is not provided.
+        method: Literal["split", "quantile"]
+            Conformalization method:
+            - "split": standard split conformal (constant-width intervals).
+            - "quantile": Conformalized Quantile Regression (input-dependent
+              intervals).
+            Defaults to "split".
         activation_cls: type[nn.Module]
             Activation function to use in the hidden layers. Defaults to `nn.ReLU`.
         loss_fn_cls: type[nn.Module]
@@ -218,6 +397,9 @@ def __init__(
             None.
         scheduler_params: dict | None
             Additional keyword arguments related to the scheduler.
+        quantile_emulator_kwargs: dict | None
+            Additional keyword arguments for the quantile emulators when
+            method="quantile". Defaults to None.
         """
         nn.Module.__init__(self)
 
@@ -242,12 +424,37 @@ def __init__(
             scheduler_cls=scheduler_cls,
             scheduler_params=scheduler_params,
         )
+
+        quantile_defaults = {
+            "standardize_x": standardize_x,
+            "standardize_y": standardize_y,
+            "activation_cls": activation_cls,
+            "loss_fn_cls": loss_fn_cls,
+            "epochs": epochs,
+            "batch_size": batch_size,
+            "layer_dims": layer_dims,
+            "weight_init": weight_init,
+            "scale": scale,
+            "bias_init": bias_init,
+            "dropout_prob": dropout_prob,
+            "lr": lr,
+            "params_size": params_size,
+            "random_seed": random_seed,
+            "scheduler_cls": scheduler_cls,
+            "scheduler_params": scheduler_params,
+        }
+        merged_quantile_kwargs = {
+            **quantile_defaults,
+            **(quantile_emulator_kwargs or {}),
+        }
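+        # merged_quantile_kwargs mirrors this MLP's own configuration as defaults;
+        # explicit quantile_emulator_kwargs entries are unpacked last and so
+        # take precedence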
         Conformal.__init__(
             self,
             emulator=emulator,
             alpha=alpha,
             device=device,
             calibration_ratio=calibration_ratio,
+            method=method,
+            quantile_emulator_kwargs=merged_quantile_kwargs,
         )
 
     @staticmethod