File tree Expand file tree Collapse file tree 2 files changed +1
-4
lines changed Expand file tree Collapse file tree 2 files changed +1
-4
lines changed Original file line number Diff line number Diff line change @@ -286,19 +286,16 @@ def forward(
286286
287287 sample = self .conv_in (sample )
288288
289- upscale_dtype = next (iter (self .up_blocks .parameters ())).dtype
290289 if torch .is_grad_enabled () and self .gradient_checkpointing :
291290 # middle
292291 sample = self ._gradient_checkpointing_func (self .mid_block , sample , latent_embeds )
293- sample = sample .to (upscale_dtype )
294292
295293 # up
296294 for up_block in self .up_blocks :
297295 sample = self ._gradient_checkpointing_func (up_block , sample , latent_embeds )
298296 else :
299297 # middle
300298 sample = self .mid_block (sample , latent_embeds )
301- sample = sample .to (upscale_dtype )
302299
303300 # up
304301 for up_block in self .up_blocks :
Original file line number Diff line number Diff line change @@ -335,7 +335,7 @@ def from_pretrained(
335335 )
336336 expected_kwargs , optional_kwargs = block_cls ._get_signature_keys (block_cls )
337337 block_kwargs = {
338- name : kwargs .pop (name ) for name in kwargs if name in expected_kwargs or name in optional_kwargs
338+ name : kwargs .get (name ) for name in kwargs if name in expected_kwargs or name in optional_kwargs
339339 }
340340
341341 return block_cls (** block_kwargs )
You can’t perform that action at this time.
0 commit comments