Skip to content

Commit 7f7d72d

Browse files
ADD update_weights for flux and qwen_image (#168)
* ADD update_weights for flux and qwen_image * Apply suggestion from @gemini-code-assist[bot] Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * refactor * Apply suggestion from @gemini-code-assist[bot] Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Apply suggestion from @gemini-code-assist[bot] Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Apply suggestion from @gemini-code-assist[bot] Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Apply suggestion from @gemini-code-assist[bot] Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * move update_component function --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent e068d10 commit 7f7d72d

File tree

3 files changed

+26
-0
lines changed

3 files changed

+26
-0
lines changed

diffsynth_engine/pipelines/base.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,20 @@ def from_pretrained(cls, model_path_or_config: str | BaseConfig) -> "BasePipelin
5151
def from_state_dict(cls, state_dicts: BaseStateDicts, pipeline_config: BaseConfig) -> "BasePipeline":
5252
raise NotImplementedError()
5353

54+
def update_weights(self, state_dicts: BaseStateDicts) -> None:
55+
raise NotImplementedError()
56+
57+
@staticmethod
58+
def update_component(
59+
component: torch.nn.Module,
60+
state_dict: Dict[str, torch.Tensor],
61+
device: str,
62+
dtype: torch.dtype,
63+
) -> None:
64+
if component and state_dict:
65+
component.load_state_dict(state_dict, assign=True)
66+
component.to(device=device, dtype=dtype, non_blocking=True)
67+
5468
def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
5569
for lora_path, lora_scale in lora_list:
5670
logger.info(f"loading lora from {lora_path} with scale {lora_scale}")

diffsynth_engine/pipelines/flux_image.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,13 @@ def _from_state_dict(cls, state_dicts: FluxStateDicts, config: FluxPipelineConfi
573573
pipe.compile()
574574
return pipe
575575

576+
def update_weights(self, state_dicts: FluxStateDicts) -> None:
577+
self.update_component(self.dit, state_dicts.model, self.config.device, self.config.model_dtype)
578+
self.update_component(self.text_encoder_1, state_dicts.clip, self.config.device, self.config.clip_dtype)
579+
self.update_component(self.text_encoder_2, state_dicts.t5, self.config.device, self.config.t5_dtype)
580+
self.update_component(self.vae_decoder, state_dicts.vae, self.config.device, self.config.vae_dtype)
581+
self.update_component(self.vae_encoder, state_dicts.vae, self.config.device, self.config.vae_dtype)
582+
576583
def compile(self):
577584
self.dit.compile_repeated_blocks(dynamic=True)
578585

diffsynth_engine/pipelines/qwen_image.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,11 @@ def _from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePip
254254
pipe.compile()
255255
return pipe
256256

257+
def update_weights(self, state_dicts: QwenImageStateDicts) -> None:
258+
self.update_component(self.dit, state_dicts.model, self.config.device, self.config.model_dtype)
259+
self.update_component(self.encoder, state_dicts.encoder, self.config.device, self.config.encoder_dtype)
260+
self.update_component(self.vae, state_dicts.vae, self.config.device, self.config.vae_dtype)
261+
257262
def compile(self):
258263
self.dit.compile_repeated_blocks(dynamic=True)
259264

0 commit comments

Comments
 (0)