From bc044d91f7f03172f25ab5e9ea479dd1beea943f Mon Sep 17 00:00:00 2001
From: RandomGitUser321
Date: Fri, 21 Jun 2024 08:48:56 -0400
Subject: [PATCH 01/13] add in interrupt callback

---
 .../pipelines/pixart_alpha/pipeline_pixart_sigma.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
index 69f028914774..4409ec1fb86f 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
@@ -605,7 +605,11 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
         return latents
-
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -744,7 +748,8 @@ def __call__(
             prompt_attention_mask,
             negative_prompt_attention_mask,
         )
-
+        self._interrupt = False
+
         # 2. Default height and width to transformer
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -812,6 +817,8 @@ def __call__(

         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

@@ -859,7 +866,7 @@ def __call__(
                     progress_bar.update()
                     if callback is not None and i % callback_steps == 0:
                         step_idx = i // getattr(self.scheduler, "order", 1)
-                        callback(step_idx, t, latents)
+                        callback(self, step_idx, t, latents)  # not 100% sure if this will break anything; the callback documentation would need to be updated to reflect the added input

         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

From d5082184ab84e95ca55c78b46989956f0105fefb Mon Sep 17 00:00:00 2001
From: RandomGitUser321
Date: Sat, 22 Jun 2024 22:59:22 -0400
Subject: [PATCH 02/13] Callbacks should be working now

Had to change some things around compared to the original file in order to
get the callbacks to work correctly. I based much of the layout on newer
diffusers pipelines such as SD3.
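For reference, interruption can then be driven through the new step-end
callback. This is only a usage sketch, not part of the patch itself: `pipe`
is assumed to be an already-loaded `PixArtSigmaPipeline`, and the step
threshold is arbitrary.

    def interrupt_after(pipe, step, timestep, callback_kwargs):
        # setting the flag makes the denoising loop skip all remaining steps
        if step >= 10:
            pipe._interrupt = True
        return callback_kwargs

    image = pipe(
        "a small cactus with a happy face in the Sahara desert",
        callback_on_step_end=interrupt_after,
    ).images[0]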
---
 .../pixart_alpha/pipeline_pixart_sigma.py | 200 +++++++++++-------
 1 file changed, 125 insertions(+), 75 deletions(-)

diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
index 4409ec1fb86f..33ebb363ada5 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
@@ -16,11 +16,12 @@
 import inspect
 import re
 import urllib.parse as ul
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import torch
 from transformers import T5EncoderModel, T5Tokenizer

+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PixArtImageProcessor
 from ...models import AutoencoderKL, PixArtTransformer2DModel
 from ...schedulers import KarrasDiffusionSchedulers
@@ -196,7 +197,8 @@ class PixArtSigmaPipeline(DiffusionPipeline):

     _optional_components = ["tokenizer", "text_encoder"]
     model_cpu_offload_seq = "text_encoder->transformer->vae"
-
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
     def __init__(
         self,
         tokenizer: T5Tokenizer,
@@ -208,25 +210,38 @@ def __init__(
         super().__init__()

         self.register_modules(
-            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+            tokenizer=tokenizer,
+            text_encoder=text_encoder,
+            vae=vae,
+            transformer=transformer,
+            scheduler=scheduler,
         )

-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = (
+            2 ** (len(self.vae.config.block_out_channels) - 1)
+        )
         self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.tokenizer_max_length = (
+            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 512  # from the actual tokenizer_config.json file
+        )
+        self.default_sample_size = (
+            self.transformer.config.sample_size
+            if hasattr(self, "transformer") and self.transformer is not None
+            else 128  # from config.json
+        )

-    # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt with 120->300
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        do_classifier_free_guidance: bool = True,
-        negative_prompt: str = "",
-        num_images_per_prompt: int = 1,
-        device: Optional[torch.device] = None,
-        prompt_embeds: Optional[torch.Tensor] = None,
-        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         prompt_attention_mask: Optional[torch.Tensor] = None,
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
         clean_caption: bool = False,
+        do_classifier_free_guidance: bool = True,
         max_sequence_length: int = 300,
         **kwargs,
     ):
@@ -256,30 +271,23 @@ def encode_prompt(
                 If `True`, the function will preprocess and clean the provided caption before encoding.
             max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
         """
-
         if "mask_feature" in kwargs:
             deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + + device = device or self._execution_device - if device is None: - device = self._execution_device - - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] - # See Section 3.1. of the paper. - max_length = max_sequence_length - if prompt_embeds is None: - prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) text_inputs = self.tokenizer( prompt, padding="max_length", - max_length=max_length, + max_length=max_sequence_length, truncation=True, add_special_tokens=True, return_tensors="pt", @@ -326,7 +334,7 @@ def encode_prompt( uncond_input = self.tokenizer( uncond_tokens, padding="max_length", - max_length=max_length, + max_length=max_sequence_length, truncation=True, return_attention_mask=True, add_special_tokens=True, @@ -381,22 +389,21 @@ def check_inputs( prompt, height, width, - negative_prompt, - callback_steps, + negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, prompt_attention_mask=None, negative_prompt_attention_mask=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: @@ -442,7 +449,7 @@ def check_inputs( f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" f" {negative_prompt_attention_mask.shape}." ) - + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing def _text_preprocessing(self, text, clean_caption=False): if clean_caption and not is_bs4_available(): @@ -584,7 +591,20 @@ def _clean_caption(self, caption): return caption.strip() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + shape = ( batch_size, num_channels_latents, @@ -605,40 +625,55 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents - + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . 
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
     @property
     def interrupt(self):
         return self._interrupt
-
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        negative_prompt: str = "",
-        num_inference_steps: int = 20,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 30,
         timesteps: List[int] = None,
         sigmas: List[float] = None,
+        eta: float = 0.0,
         guidance_scale: float = 4.5,
         num_images_per_prompt: Optional[int] = 1,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
-        eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.Tensor] = None,
-        prompt_embeds: Optional[torch.Tensor] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
         prompt_attention_mask: Optional[torch.Tensor] = None,
-        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
-        callback_steps: int = 1,
         clean_caption: bool = True,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         use_resolution_binning: bool = True,
         max_sequence_length: int = 300,
         **kwargs,
-    ) -> Union[ImagePipelineOutput, Tuple]:
+    ):
         """
         Function invoked when calling the pipeline for generation.

@@ -716,13 +751,18 @@ def __call__(
         Examples:

         Returns:
-            [`~pipelines.ImagePipelineOutput`] or `tuple`:
+            [`~pipelines.ImagePipelineOutput`]:
                 If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                 returned where the first element is a list with the generated images
         """
+
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
         # 1. Check inputs. Raise error if not correct
         height = height or self.transformer.config.sample_size * self.vae_scale_factor
         width = width or self.transformer.config.sample_size * self.vae_scale_factor
+
         if use_resolution_binning:
             if self.transformer.config.sample_size == 256:
                 aspect_ratio_bin = ASPECT_RATIO_2048_BIN
@@ -741,14 +781,13 @@ def __call__(
             prompt,
             height,
             width,
-            negative_prompt,
-            callback_steps,
-            prompt_embeds,
-            negative_prompt_embeds,
-            prompt_attention_mask,
-            negative_prompt_attention_mask,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            prompt_attention_mask=prompt_attention_mask,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
         )
-        self._interrupt = False

         # 2. Default height and width to transformer
         if prompt is not None and isinstance(prompt, str):
@@ -757,14 +796,12 @@ def __call__(
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]
-
+
+        self._guidance_scale = guidance_scale
+
+        self._interrupt = False
         device = self._execution_device

-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0
-
         # 3. Encode input prompt
         (
             prompt_embeds,
@@ -772,26 +809,27 @@ def __call__(
             negative_prompt_embeds,
             negative_prompt_attention_mask,
         ) = self.encode_prompt(
-            prompt,
-            do_classifier_free_guidance,
+            prompt=prompt,
             negative_prompt=negative_prompt,
-            num_images_per_prompt=num_images_per_prompt,
-            device=device,
+            do_classifier_free_guidance=self.do_classifier_free_guidance,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             prompt_attention_mask=prompt_attention_mask,
             negative_prompt_attention_mask=negative_prompt_attention_mask,
+            device=device,
             clean_caption=clean_caption,
+            num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
         )
-        if do_classifier_free_guidance:
+
+        if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

         # 4. Prepare timesteps
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler, num_inference_steps, device, timesteps, sigmas
-        )
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, sigmas)
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+        self._num_timesteps = len(timesteps)

         # 5. Prepare latents.
         latent_channels = self.transformer.config.in_channels
@@ -813,13 +851,12 @@ def __call__(
             added_cond_kwargs = {"resolution": None, "aspect_ratio": None}

         # 7. Denoising loop
-        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                 current_timestep = t
@@ -848,9 +885,9 @@ def __call__(
                 )[0]

                 # perform guidance
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                 # learned sigma
                 if self.transformer.config.out_channels // 2 == latent_channels:
@@ -859,14 +896,27 @@ def __call__(
                     noise_pred = noise_pred

                 # compute previous image: x_t -> x_t-1
+                latents_dtype = latents.dtype
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+                if latents.dtype != latents_dtype:
+                    if torch.backends.mps.is_available():
+                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                        latents = latents.to(latents_dtype)
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
-                    if callback is not None and i % callback_steps == 0:
-                        step_idx = i // getattr(self.scheduler, "order", 1)
-                        callback(self, step_idx, t, latents)  # not 100% sure if this will break anything; the callback documentation would need to be updated to reflect the added input

         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
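Usage sketch for the tensor-inputs plumbing added above (not part of the
patch itself; assumes an already-loaded `PixArtSigmaPipeline` as `pipe`).
Any name listed in `_callback_tensor_inputs` can be requested:

    def inspect_step(pipe, step, timestep, callback_kwargs):
        # tensors requested via callback_on_step_end_tensor_inputs arrive in callback_kwargs
        latents = callback_kwargs["latents"]
        print(f"step {step}: latents {tuple(latents.shape)}, dtype {latents.dtype}")
        # anything returned here is written back into the denoising loop
        return callback_kwargs

    image = pipe(
        "a small cactus with a happy face in the Sahara desert",
        callback_on_step_end=inspect_step,
        callback_on_step_end_tensor_inputs=["latents"],
    ).images[0]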
From e6c0b05d6d8832b0cf18b113887f9238bf1de4d8 Mon Sep 17 00:00:00 2001
From: RandomGitUser321
Date: Sat, 22 Jun 2024 23:17:09 -0400
Subject: [PATCH 03/13] corrected a few variables

---
 .../pipelines/pixart_alpha/pipeline_pixart_sigma.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
index 33ebb363ada5..d502aa2176e9 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
@@ -660,10 +660,10 @@ def __call__(
         guidance_scale: float = 4.5,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
         prompt_attention_mask: Optional[torch.Tensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,

From a8ba299992e9c008aceb75fa67fe309dbfb4c3fe Mon Sep 17 00:00:00 2001
From: RandomGitUser321
Date: Sat, 22 Jun 2024 23:43:01 -0400
Subject: [PATCH 04/13] missed a couple more

---
 src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
index d502aa2176e9..09adc41db343 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
@@ -234,8 +234,8 @@ def encode_prompt(
         self,
         prompt: Union[str, List[str]],
         negative_prompt: Optional[Union[str, List[str]]] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         prompt_attention_mask: Optional[torch.Tensor] = None,
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         device: Optional[torch.device] = None,
From c1f2577a06c2103889209f988efcc26b16abc617 Mon Sep 17 00:00:00 2001
From: RandomGitUser321
Date: Sun, 23 Jun 2024 12:19:56 -0400
Subject: [PATCH 05/13] forgot to put that back in

---
 src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
index 09adc41db343..e29a92023b5f 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
@@ -284,6 +284,7 @@ def encode_prompt(
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
+            prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",

From 47b327e705f922093895bafe2fc85082ff663860 Mon Sep 17 00:00:00 2001
From: RandomGitUser321
Date: Sun, 23 Jun 2024 12:29:31 -0400
Subject: [PATCH 06/13] more minor fixes

---
 .../pipelines/pixart_alpha/pipeline_pixart_sigma.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
index e29a92023b5f..a5482b91faca 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
@@ -618,10 +618,7 @@ def prepare_latents(
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )

-        if latents is None:
-            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        else:
-            latents = latents.to(device)
+        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma

From 811f5d5e5f0a7438401b3697053b4cc7bfdb9dea Mon Sep 17 00:00:00 2001
From: RandomGitUser321
Date: Fri, 28 Jun 2024 09:34:14 -0400
Subject: [PATCH 07/13] implement cfg rescaling to have parity with common
 pipelines, based on https://arxiv.org/pdf/2305.08891.pdf (see Section 3.4)

---
 .../pixart_alpha/pipeline_pixart_sigma.py | 30 +++++++++++++++++--
 1 file changed, 28 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
index a5482b91faca..db11c71e6b6f 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
@@ -172,7 +172,20 @@ def retrieve_timesteps(
         scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
         timesteps = scheduler.timesteps
     return timesteps, num_inference_steps
-
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+    """
+    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+    """
+    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    # rescale the results from guidance (fixes overexposure)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+    return noise_cfg

 class PixArtSigmaPipeline(DiffusionPipeline):
     r"""
@@ -627,7 +640,11 @@ def prepare_latents(
     @property
     def guidance_scale(self):
         return self._guidance_scale
-
+
+    @property
+    def guidance_rescale(self):
+        return self._guidance_rescale
+
     # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
     # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
     # corresponds to doing no classifier free guidance.
     @property
     def do_classifier_free_guidance(self):
@@ -666,6 +683,7 @@ def __call__(
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         clean_caption: bool = True,
+        guidance_rescale: float = 0.0,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         use_resolution_binning: bool = True,
@@ -740,6 +758,9 @@ def __call__(
                 Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
                 be installed. If the dependencies are not installed, the embeddings will be created from the raw
                 prompt.
+            guidance_rescale (`float`, *optional*, defaults to 0.0):
+                Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise
+                Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
             use_resolution_binning (`bool` defaults to `True`):
                 If set to `True`, the requested height and width are first mapped to the closest resolutions using
                 `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
@@ -796,6 +817,7 @@ def __call__(
             batch_size = prompt_embeds.shape[0]

         self._guidance_scale = guidance_scale
+        self._guidance_rescale = guidance_rescale

         self._interrupt = False
         device = self._execution_device
@@ -886,6 +908,10 @@ def __call__(
                 if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                if self.do_classifier_free_guidance and guidance_rescale > 0.0:
+                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                 # learned sigma
                 if self.transformer.config.out_channels // 2 == latent_channels:
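Usage sketch for the rescale knob added above (not part of the patch;
assumes an already-loaded `PixArtSigmaPipeline` as `pipe`). It stays off by
default (`guidance_rescale=0.0`); `0.7` mirrors the value the paper reports
working well:

    image = pipe(
        "a small cactus with a happy face in the Sahara desert",
        guidance_scale=7.0,
        guidance_rescale=0.7,
    ).images[0]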
From 4332af43e764f23cb50bfa10cd6e50bd14a7084c Mon Sep 17 00:00:00 2001
From: RandomGitUser321
Date: Fri, 28 Jun 2024 15:35:32 -0400
Subject: [PATCH 08/13] slight reorder and reword

---
 .../pipelines/pixart_alpha/pipeline_pixart_sigma.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
index db11c71e6b6f..2fa662a9590c 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
@@ -616,9 +616,6 @@ def prepare_latents(
         generator,
         latents=None,
     ):
-        if latents is not None:
-            return latents.to(device=device, dtype=dtype)
-
         shape = (
             batch_size,
             num_channels_latents,
@@ -630,8 +627,11 @@ def prepare_latents(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
-
-        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+        if latents is None:
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        else:
+            latents = latents.to(device)

         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
         return latents

From d6f0aabe687659edfa68fdd2eacb2f93a6750b10 Mon Sep 17 00:00:00 2001
From: RandomGitUser321
Date: Fri, 28 Jun 2024 21:24:38 -0400
Subject: [PATCH 09/13] re-add legacy callback/callback_steps functionality

The legacy callback will require (self, step_idx, t, latents), but has 1:1
parity with the newer callback_on_step_end method. I also included a
deprecation warning and an error if both are used at the same time.
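To spell out the contract enforced by the diff below (hypothetical
callables, not part of the patch itself; assumes a loaded pipeline as
`pipe`):

    def old_cb(pipe, step_idx, t, latents):
        ...  # legacy signature as wired up at this point in the series

    def new_cb(pipe, step, t, callback_kwargs):
        return callback_kwargs

    pipe("a red panda", callback=old_cb)              # works, with a deprecation warning
    pipe("a red panda", callback_on_step_end=new_cb)  # works
    pipe("a red panda", callback=old_cb, callback_on_step_end=new_cb)  # raises ValueError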
---
 .../pixart_alpha/pipeline_pixart_sigma.py | 51 +++++++++++++++----
 1 file changed, 42 insertions(+), 9 deletions(-)

diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
index 2fa662a9590c..3477df132fc1 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
@@ -16,7 +16,7 @@
 import inspect
 import re
 import urllib.parse as ul
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union, Tuple

 import torch
 from transformers import T5EncoderModel, T5Tokenizer
@@ -404,6 +404,7 @@ def check_inputs(
         height,
         width,
         negative_prompt=None,
+        callback_steps=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
@@ -413,6 +414,14 @@ def check_inputs(
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

+        if (callback_steps is None) or (
+            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+        ):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                f" {type(callback_steps)}."
+            )
+
         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
         ):
@@ -684,12 +693,14 @@ def __call__(
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         clean_caption: bool = True,
         guidance_rescale: float = 0.0,
+        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,  # maybe add logic so the new callback takes priority if both are passed? for now it errors, asking to use one or the other
+        callback_steps: int = 1,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         use_resolution_binning: bool = True,
         max_sequence_length: int = 300,
         **kwargs,
-    ):
+    ) -> Union[ImagePipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.

@@ -718,6 +729,9 @@ def __call__(
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            guidance_rescale (`float`, *optional*, defaults to 0.0):
+                Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise
+                Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             height (`int`, *optional*, defaults to self.unet.config.sample_size):
@@ -750,17 +764,24 @@ def __call__(
                 Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
             callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. The function will be
-                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+                called with the following arguments: `callback(self, step: int, timestep: int, latents: torch.Tensor)`.
+                This feature will be deprecated soon; use `callback_on_step_end` instead.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
-                called at every step.
+                called at every step. This feature will be deprecated soon; use `callback_on_step_end` instead.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
             clean_caption (`bool`, *optional*, defaults to `True`):
                 Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
                 be installed. If the dependencies are not installed, the embeddings will be created from the raw
                 prompt.
-            guidance_rescale (`float`, *optional*, defaults to 0.0):
-                Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise
-                Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
             use_resolution_binning (`bool` defaults to `True`):
                 If set to `True`, the requested height and width are first mapped to the closest resolutions using
                 `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
                 the requested resolution. Useful for generating non-square images.

         Examples:

         Returns:
-            [`~pipelines.ImagePipelineOutput`]:
+            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                 If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                 returned where the first element is a list with the generated images
         """
+        if callback is not None:
+            deprecation_message = "The use of `callback` will soon be deprecated and will be removed in a future version. It is recommended to use `callback_on_step_end` instead."
+            deprecate("callback", "1.0.0", deprecation_message, standard_warn=False)
+
+        if (callback is not None) and (callback_on_step_end is not None):
+            raise ValueError(
+                "Cannot use both `callback` (will soon be deprecated) and `callback_on_step_end` at the same time. Use one or the other."
+            )

         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
@@ -834,6 +864,7 @@ def __call__(
             prompt_attention_mask=prompt_attention_mask,
             negative_prompt_attention_mask=negative_prompt_attention_mask,
+            callback_steps=callback_steps,  # will be deprecated soon
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
         )
@@ -938,9 +968,12 @@ def __call__(
                     latents = callback_outputs.pop("latents", latents)
                     prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                     negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

-                # call the callback, if provided
+                # call the callback, if provided (will be deprecated soon; use callback_on_step_end instead)
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        step_idx = i // getattr(self.scheduler, "order", 1)
+                        callback(self, step_idx, t, latents)

         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

From a6baf665fa64f5152cf349ce54eb88307219fbf4 Mon Sep 17 00:00:00 2001
From: RandomGitUser321
Date: Sat, 29 Jun 2024 10:20:20 -0400
Subject: [PATCH 10/13] revert negative_prompt changes

---
 .../pipelines/pixart_alpha/pipeline_pixart_sigma.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
index 3477df132fc1..559952d29c2f 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
@@ -246,7 +246,7 @@ def __init__(
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt: str = "",
         prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
         prompt_attention_mask: Optional[torch.Tensor] = None,
@@ -403,7 +403,7 @@ def check_inputs(
         prompt,
         height,
         width,
-        negative_prompt=None,
+        negative_prompt,
         callback_steps=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
@@ -472,7 +472,7 @@ def check_inputs(
                 f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
                 f" {negative_prompt_attention_mask.shape}."
             )
-
+
     # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
     def _text_preprocessing(self, text, clean_caption=False):
         if clean_caption and not is_bs4_available():
@@ -636,7 +636,7 @@ def prepare_latents(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
-
+
         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
             latents = latents.to(device)
@@ -674,7 +674,7 @@ def __call__(
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt: str = "",
         height: Optional[int] = None,
         width: Optional[int] = None,
@@ -837,7 +837,7 @@ def __call__(
             callback_steps=callback_steps,  # will be deprecated soon
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
         )
-
+
         # 2. Default height and width to transformer
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
From 6f6576accf260e2e7c11c6a0a19d5ec17ce289b3 Mon Sep 17 00:00:00 2001
From: RandomGitUser321
Date: Sat, 29 Jun 2024 10:30:12 -0400
Subject: [PATCH 11/13] unused import

---
 src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
index 559952d29c2f..cb9f9c328aea 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
@@ -16,7 +16,7 @@
 import inspect
 import re
 import urllib.parse as ul
-from typing import Any, Callable, Dict, List, Optional, Union, Tuple
+from typing import Callable, Dict, List, Optional, Union, Tuple

 import torch
 from transformers import T5EncoderModel, T5Tokenizer

From bb49215ce90f4289abe0de18508649cb25626929 Mon Sep 17 00:00:00 2001
From: RandomGitUser321
Date: Sun, 30 Jun 2024 09:44:34 -0400
Subject: [PATCH 12/13] revert small change in legacy callback

Since the original implementation doesn't appear to have the ability to
interrupt, I'm just going to roll this back. If people want to interrupt,
they need to use the newer method anyway, since the older callback method
is being deprecated. The legacy callback will still provide step_idx, t,
and latents, like before.

---
 src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
index cb9f9c328aea..15653a5e4958 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
@@ -973,7 +973,7 @@ def __call__(
                     progress_bar.update()
                     if callback is not None and i % callback_steps == 0:
                         step_idx = i // getattr(self.scheduler, "order", 1)
-                        callback(self, step_idx, t, latents)
+                        callback(step_idx, t, latents)

         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
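After this rollback the legacy path keeps its original three-argument
shape; a minimal usage sketch (not part of the patch; assumes a loaded
pipeline as `pipe`):

    def legacy_progress(step_idx, timestep, latents):
        # purely illustrative progress logging
        print(f"step {step_idx}: latents {tuple(latents.shape)}")

    image = pipe(
        "a small cactus with a happy face in the Sahara desert",
        callback=legacy_progress,
        callback_steps=5,
    ).images[0]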
- deprecate("callback", "1.0.0", deprecation_message, standard_warn=False) - if (callback is not None) and (callback_on_step_end is not None): - raise ValueError( - f"Cannot use both `callback`(will soon be deprecated) and `callback_on_step_end` at the same time. Use one or the other." - ) + # from diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs