@@ -541,9 +541,6 @@ def create_conformal_subclass(
541541 alpha : float = fixed_kwargs .get ("alpha" , 0.95 )
542542 calibration_ratio : float = fixed_kwargs .get ("calibration_ratio" , 0.2 )
543543 quantile_emulator_kwargs : dict | None = fixed_kwargs .get ("quantile_emulator_kwargs" )
544- epochs = fixed_kwargs .get ("epochs" , 50 )
545- lr = fixed_kwargs .get ("lr" , 2e-1 )
546- device = fixed_kwargs .get ("device" )
547544
548545 class ConformalMLPSubclass (conformal_mlp_base_class ):
549546 def __init__ (
@@ -603,19 +600,18 @@ def get_tune_params():
603600 """Get tunable parameters, excluding those that are fixed."""
604601 tune_params = conformal_mlp_base_class .get_tune_params ()
605602 # Remove fixed parameters from tuning
606- tune_params .pop ("mean_module_fn" , None )
607- tune_params .pop ("covar_module_fn" , None )
603+ tune_params .pop ("method" , None )
608604 for key in fixed_kwargs :
609605 tune_params .pop (key , None )
610606 return tune_params
611607
612608 # Create a more descriptive docstring that includes fixed parameters
613- mean_covar_and_fixed_kwargs = {
609+ method_and_fixed_kwargs = {
614610 ** fixed_kwargs ,
615611 }
616612 fixed_params_str = "\n " .join (
617613 f"- { k } = { v .__name__ if callable (v ) else v } "
618- for k , v in mean_covar_and_fixed_kwargs .items ()
614+ for k , v in method_and_fixed_kwargs .items ()
619615 )
620616
621617 ConformalMLPSubclass .__doc__ = f"""
0 commit comments