Skip to content

Commit d65f857

Browse files
authored
Merge branch 'main' into cp-fix
2 parents 3b12a0b + df8dd77 commit d65f857

File tree

2 files changed

+1
-4
lines changed

2 files changed

+1
-4
lines changed

src/diffusers/models/autoencoders/vae.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -286,19 +286,16 @@ def forward(
286286

287287
sample = self.conv_in(sample)
288288

289-
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
290289
if torch.is_grad_enabled() and self.gradient_checkpointing:
291290
# middle
292291
sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
293-
sample = sample.to(upscale_dtype)
294292

295293
# up
296294
for up_block in self.up_blocks:
297295
sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds)
298296
else:
299297
# middle
300298
sample = self.mid_block(sample, latent_embeds)
301-
sample = sample.to(upscale_dtype)
302299

303300
# up
304301
for up_block in self.up_blocks:

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ def from_pretrained(
335335
)
336336
expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls)
337337
block_kwargs = {
338-
name: kwargs.pop(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs
338+
name: kwargs.get(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs
339339
}
340340

341341
return block_cls(**block_kwargs)

0 commit comments

Comments
 (0)