Is your feature request related to a problem? Please describe.
I don't like that the RePaint scheduler implements the `leading` timestep_spacing method by default, with no way to control it. The `leading` method is problematic for the average user: when the number of training timesteps isn't divisible by the number of inference steps, it produces counter-intuitive results. E.g. with 300 inference steps and 500 training timesteps, the timesteps array starts from 299 instead of 499, so sampling never begins at the fully-noised end of the schedule. This is rarely desirable.
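To make this concrete, here is a minimal standalone sketch (plain NumPy, independent of the scheduler; the variable names are mine) of what the two spacing methods produce in that case:

```python
import numpy as np

num_train_timesteps = 500
num_inference_steps = 300

# "leading": integer step ratio, counting up from 0
step_ratio = num_train_timesteps // num_inference_steps  # 500 // 300 == 1
leading = (np.arange(num_inference_steps) * step_ratio).round().astype(np.int64)

# "trailing": fractional step ratio, counting down from the end
step_ratio = num_train_timesteps / num_inference_steps
trailing = np.round(np.arange(num_train_timesteps, 0, -step_ratio)).astype(np.int64) - 1

print(leading.max())   # 299 -- the last ~200 training timesteps are never visited
print(trailing.max())  # 499 -- sampling starts at the fully-noised end
```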
Describe the solution you'd like.
An option to set `timestep_spacing`, like other schedulers have.
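For comparison, other schedulers already expose this through their config; e.g. with DDIM (assuming a recent diffusers version where `DDIMScheduler` accepts `timestep_spacing`):

```python
from diffusers import DDIMScheduler

# the same 300-of-500 setup, but the spacing is user-selectable
scheduler = DDIMScheduler(num_train_timesteps=500, timestep_spacing="trailing")
scheduler.set_timesteps(num_inference_steps=300)
print(scheduler.timesteps[0])  # tensor(499)
```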
Describe alternatives you've considered.
Currently I'm making a new class that inherits from `RePaintScheduler`, overrides the `set_timesteps` method, and defaults to `trailing`:
```python
from typing import Optional, Union

import numpy as np
import torch

from diffusers import RePaintScheduler


class RePaintSchedulerTSSpacing(RePaintScheduler):
    def set_timesteps(
        self,
        num_inference_steps: int,
        jump_length: int = 10,
        jump_n_sample: int = 10,
        device: Optional[Union[str, torch.device]] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            jump_length (`int`, defaults to 10):
                The number of steps taken forward in time before going backward in time for a single jump.
            jump_n_sample (`int`, defaults to 10):
                The number of times to make a forward time jump for a given chosen time sample.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved.
        """
        num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
        self.num_inference_steps = num_inference_steps

        # RePaintScheduler's config doesn't define `timestep_spacing`, so default to "trailing"
        timestep_spacing = getattr(self.config, "timestep_spacing", "trailing")

        # 1. Generate the correctly spaced base timesteps, in ascending order
        if timestep_spacing == "linspace":
            base_timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
            base_timesteps = np.round(base_timesteps).astype(np.int64)
        elif timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // num_inference_steps
            base_timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().astype(np.int64)
            base_timesteps += getattr(self.config, "steps_offset", 0)
        elif timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / num_inference_steps
            base_timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
            base_timesteps -= 1
            base_timesteps = base_timesteps[::-1].copy()  # ensure ascending order for index mapping
        else:
            raise ValueError(
                f"timestep_spacing must be one of ['linspace', 'leading', 'trailing'], got {timestep_spacing}"
            )

        # 2. Generate the jumping sequence of indices (num_inference_steps - 1 down to 0, with jumps)
        # This is the original RePaint logic, applied to indices instead of final timesteps
        jump_indices = []
        jumps = {}
        for j in range(0, num_inference_steps - jump_length, jump_length):
            jumps[j] = jump_n_sample - 1

        t_idx = num_inference_steps
        while t_idx >= 1:
            t_idx = t_idx - 1
            jump_indices.append(t_idx)
            if jumps.get(t_idx, 0) > 0:
                jumps[t_idx] = jumps[t_idx] - 1
                for _ in range(jump_length):
                    t_idx = t_idx + 1
                    jump_indices.append(t_idx)

        # 3. Map the jumping indices onto the correctly spaced base timesteps
        timesteps = base_timesteps[jump_indices]
        self.timesteps = torch.from_numpy(timesteps).to(device)
```
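A quick usage sketch of the subclass (with an unmodified RePaint config it falls back to `trailing`; calling `register_to_config` after construction is my assumption for how one would select another spacing, since the base config has no such field):

```python
scheduler = RePaintSchedulerTSSpacing(num_train_timesteps=500)
scheduler.set_timesteps(num_inference_steps=300)
print(scheduler.timesteps[0])  # tensor(499): sampling now starts at the fully-noised end

# opting back into the old behaviour after construction
scheduler.register_to_config(timestep_spacing="leading")
scheduler.set_timesteps(num_inference_steps=300)
print(scheduler.timesteps[0])  # tensor(299): the old counter-intuitive start
```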
I think this should be the fix. I can make a PR if this is considered OK.
Additional context.
Visual comparison.
