@@ -158,6 +158,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
158158        use_karras_sigmas (`bool`, *optional*, defaults to `False`): 
159159            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, 
160160            the sigmas are determined according to a sequence of noise levels {σi}. 
161+         use_exponential_sigmas (`bool`, *optional*, defaults to `False`): 
162+             Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. 
161163        timestep_spacing (`str`, defaults to `"linspace"`): 
162164            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and 
163165            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. 
@@ -186,6 +188,7 @@ def __init__(
186188        prediction_type : str  =  "epsilon" ,
187189        interpolation_type : str  =  "linear" ,
188190        use_karras_sigmas : Optional [bool ] =  False ,
191+         use_exponential_sigmas : Optional [bool ] =  False ,
189192        sigma_min : Optional [float ] =  None ,
190193        sigma_max : Optional [float ] =  None ,
191194        timestep_spacing : str  =  "linspace" ,
@@ -235,6 +238,7 @@ def __init__(
235238
236239        self .is_scale_input_called  =  False 
237240        self .use_karras_sigmas  =  use_karras_sigmas 
241+         self .use_exponential_sigmas  =  use_exponential_sigmas 
238242
239243        self ._step_index  =  None 
240244        self ._begin_index  =  None 
@@ -332,6 +336,12 @@ def set_timesteps(
332336            raise  ValueError ("Can only pass one of `num_inference_steps` or `timesteps` or `sigmas`." )
333337        if  timesteps  is  not None  and  self .config .use_karras_sigmas :
334338            raise  ValueError ("Cannot set `timesteps` with `config.use_karras_sigmas = True`." )
339+         if  timesteps  is  not None  and  self .config .use_exponential_sigmas :
340+             raise  ValueError ("Cannot set `timesteps` with `config.use_exponential_sigmas = True`." )
341+         if  self .config .use_exponential_sigmas  and  self .config .use_karras_sigmas :
342+             raise  ValueError (
343+                 "Cannot set both `config.use_exponential_sigmas = True` and config.use_karras_sigmas = True`" 
344+             )
335345        if  (
336346            timesteps  is  not None 
337347            and  self .config .timestep_type  ==  "continuous" 
@@ -396,6 +406,10 @@ def set_timesteps(
396406                sigmas  =  self ._convert_to_karras (in_sigmas = sigmas , num_inference_steps = self .num_inference_steps )
397407                timesteps  =  np .array ([self ._sigma_to_t (sigma , log_sigmas ) for  sigma  in  sigmas ])
398408
409+             elif  self .config .use_exponential_sigmas :
410+                 sigmas  =  self ._convert_to_exponential (in_sigmas = sigmas , num_inference_steps = self .num_inference_steps )
411+                 timesteps  =  np .array ([self ._sigma_to_t (sigma , log_sigmas ) for  sigma  in  sigmas ])
412+ 
399413            if  self .config .final_sigmas_type  ==  "sigma_min" :
400414                sigma_last  =  ((1  -  self .alphas_cumprod [0 ]) /  self .alphas_cumprod [0 ]) **  0.5 
401415            elif  self .config .final_sigmas_type  ==  "zero" :
@@ -468,6 +482,28 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to
468482        sigmas  =  (max_inv_rho  +  ramp  *  (min_inv_rho  -  max_inv_rho )) **  rho 
469483        return  sigmas 
470484
485+     # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L26 
486+     def  _convert_to_exponential (self , in_sigmas : torch .Tensor , num_inference_steps : int ) ->  torch .Tensor :
487+         """Constructs an exponential noise schedule.""" 
488+ 
489+         # Hack to make sure that other schedulers which copy this function don't break 
490+         # TODO: Add this logic to the other schedulers 
491+         if  hasattr (self .config , "sigma_min" ):
492+             sigma_min  =  self .config .sigma_min 
493+         else :
494+             sigma_min  =  None 
495+ 
496+         if  hasattr (self .config , "sigma_max" ):
497+             sigma_max  =  self .config .sigma_max 
498+         else :
499+             sigma_max  =  None 
500+ 
501+         sigma_min  =  sigma_min  if  sigma_min  is  not None  else  in_sigmas [- 1 ].item ()
502+         sigma_max  =  sigma_max  if  sigma_max  is  not None  else  in_sigmas [0 ].item ()
503+ 
504+         sigmas  =  torch .linspace (math .log (sigma_max ), math .log (sigma_min ), num_inference_steps ).exp ()
505+         return  sigmas 
506+ 
471507    def  index_for_timestep (self , timestep , schedule_timesteps = None ):
472508        if  schedule_timesteps  is  None :
473509            schedule_timesteps  =  self .timesteps 
0 commit comments