|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -import inspect |
16 | 15 | import math |
17 | 16 | from typing import Any, Callable, Dict, List, Optional, Union |
18 | 17 |
|
|
26 | 25 | from ...schedulers import FlowMatchEulerDiscreteScheduler |
27 | 26 | from ...utils import is_torch_xla_available, logging, replace_example_docstring |
28 | 27 | from ...utils.torch_utils import randn_tensor |
29 | | -from ..pipeline_utils import DiffusionPipeline |
| 28 | +from ..pipeline_utils import DiffusionPipeline, calculate_shift, retrieve_latents, retrieve_timesteps |
30 | 29 | from .pipeline_output import QwenImagePipelineOutput |
31 | 30 |
|
32 | 31 |
|
|
67 | 66 | VAE_IMAGE_SIZE = 1024 * 1024 |
68 | 67 |
|
69 | 68 |
|
70 | | -# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift |
71 | | -def calculate_shift( |
72 | | - image_seq_len, |
73 | | - base_seq_len: int = 256, |
74 | | - max_seq_len: int = 4096, |
75 | | - base_shift: float = 0.5, |
76 | | - max_shift: float = 1.15, |
77 | | -): |
78 | | - m = (max_shift - base_shift) / (max_seq_len - base_seq_len) |
79 | | - b = base_shift - m * base_seq_len |
80 | | - mu = image_seq_len * m + b |
81 | | - return mu |
82 | | - |
83 | | - |
84 | | -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps |
85 | | -def retrieve_timesteps( |
86 | | - scheduler, |
87 | | - num_inference_steps: Optional[int] = None, |
88 | | - device: Optional[Union[str, torch.device]] = None, |
89 | | - timesteps: Optional[List[int]] = None, |
90 | | - sigmas: Optional[List[float]] = None, |
91 | | - **kwargs, |
92 | | -): |
93 | | - r""" |
94 | | - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles |
95 | | - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. |
96 | | -
|
97 | | - Args: |
98 | | - scheduler (`SchedulerMixin`): |
99 | | - The scheduler to get timesteps from. |
100 | | - num_inference_steps (`int`): |
101 | | - The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` |
102 | | - must be `None`. |
103 | | - device (`str` or `torch.device`, *optional*): |
104 | | - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. |
105 | | - timesteps (`List[int]`, *optional*): |
106 | | - Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, |
107 | | - `num_inference_steps` and `sigmas` must be `None`. |
108 | | - sigmas (`List[float]`, *optional*): |
109 | | - Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, |
110 | | - `num_inference_steps` and `timesteps` must be `None`. |
111 | | -
|
112 | | - Returns: |
113 | | - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the |
114 | | - second element is the number of inference steps. |
115 | | - """ |
116 | | - if timesteps is not None and sigmas is not None: |
117 | | - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") |
118 | | - if timesteps is not None: |
119 | | - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) |
120 | | - if not accepts_timesteps: |
121 | | - raise ValueError( |
122 | | - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" |
123 | | - f" timestep schedules. Please check whether you are using the correct scheduler." |
124 | | - ) |
125 | | - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) |
126 | | - timesteps = scheduler.timesteps |
127 | | - num_inference_steps = len(timesteps) |
128 | | - elif sigmas is not None: |
129 | | - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) |
130 | | - if not accept_sigmas: |
131 | | - raise ValueError( |
132 | | - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" |
133 | | - f" sigmas schedules. Please check whether you are using the correct scheduler." |
134 | | - ) |
135 | | - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) |
136 | | - timesteps = scheduler.timesteps |
137 | | - num_inference_steps = len(timesteps) |
138 | | - else: |
139 | | - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) |
140 | | - timesteps = scheduler.timesteps |
141 | | - return timesteps, num_inference_steps |
142 | | - |
143 | | - |
144 | | -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents |
145 | | -def retrieve_latents( |
146 | | - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" |
147 | | -): |
148 | | - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": |
149 | | - return encoder_output.latent_dist.sample(generator) |
150 | | - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": |
151 | | - return encoder_output.latent_dist.mode() |
152 | | - elif hasattr(encoder_output, "latents"): |
153 | | - return encoder_output.latents |
154 | | - else: |
155 | | - raise AttributeError("Could not access latents of provided encoder_output") |
156 | | - |
157 | | - |
158 | 69 | def calculate_dimensions(target_area, ratio): |
159 | 70 | width = math.sqrt(target_area * ratio) |
160 | 71 | height = width / ratio |
|
0 commit comments