
Wan 2.2 causes NaN in latents in I2V #12613

@chaowenguo0


Describe the bug

import asyncio, aiohttp.web, math, pathlib, torch, builtins, transformers, diffusers, PIL, av, io, numpy, uvloop, concurrent, multiprocessing, os, zipfile, sys
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.hub.file_download import model_file_download
#import sd_embed.embedding_funcs

model_dir = snapshot_download('Wan-AI/Wan2.2-I2V-A14B-Diffusers')
onload_device = torch.device('cuda')

# float32 VAE, bf16 weights, 4-bit NF4 quantization on both transformers
pipe = diffusers.WanImageToVideoPipeline.from_pretrained(
    model_dir,
    vae=diffusers.AutoencoderKLWan.from_pretrained(model_dir, subfolder='vae', torch_dtype=torch.float32).to(onload_device),
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
    quantization_config=diffusers.quantizers.PipelineQuantizationConfig(
        quant_backend='bitsandbytes_4bit',
        quant_kwargs={'load_in_4bit': True, 'bnb_4bit_quant_type': 'nf4', 'bnb_4bit_compute_dtype': torch.bfloat16},
        components_to_quantize=['transformer', 'transformer_2']))

# leaf-level group offloading everywhere, plus float8 layerwise weight storage
# on both transformers (compute stays in bf16)
diffusers.hooks.apply_group_offloading(pipe.text_encoder, onload_device=onload_device, offload_type='leaf_level', use_stream=True, record_stream=True, non_blocking=True, low_cpu_mem_usage=True)
pipe.transformer.enable_group_offload(onload_device=onload_device, offload_type='leaf_level', use_stream=True, record_stream=True, non_blocking=True, low_cpu_mem_usage=True)
pipe.transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
pipe.transformer_2.enable_group_offload(onload_device=onload_device, offload_type='leaf_level', use_stream=True, record_stream=True, non_blocking=True, low_cpu_mem_usage=True)
pipe.transformer_2.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()

async def config(request):
    return aiohttp.web.json_response({1: 2})

def decode_tensors(pipe, step, timestep, callback_kwargs):
    # print the intermediate latents at every denoising step
    latents = callback_kwargs["latents"]
    print(step, latents)
    return callback_kwargs

def generate(prompt, negative, image):
    with zipfile.ZipFile('output.zip', 'w') as output:
        buffer = io.BytesIO()
        with av.open(buffer, mode='w', format='mp4') as writer:
            stream = writer.add_stream('h264', rate=8)
            stream.width = 1280
            stream.height = 720
            stream.pix_fmt = 'yuv420p'
            #prompt_embeds, prompt_neg_embeds = sd_embed.embedding_funcs.get_weighted_text_embeddings_sd15(pipe, prompt=prompt, neg_prompt=negative)
            frames = pipe(image=PIL.Image.open(io.BytesIO(image)).convert('RGB').resize((stream.width, stream.height)),
                          prompt=prompt, negative_prompt=negative, width=stream.width, height=stream.height,
                          num_frames=81, num_inference_steps=30, callback_on_step_end=decode_tensors).frames[0]
            for frame in frames:
                writer.mux(stream.encode(av.VideoFrame.from_ndarray((frame * 255).astype(numpy.uint8))))
            writer.mux(stream.encode())  # flush the encoder
        output.writestr('output.mp4', buffer.getvalue())
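
Since the report is about NaNs in the latents, the callback can be narrowed to flag only non-finite values instead of printing whole tensors. A minimal sketch against the same callback_on_step_end API used above (detect_nan is a hypothetical replacement for decode_tensors):

def detect_nan(pipe, step, timestep, callback_kwargs):
    # report any denoising step whose latents contain non-finite values
    latents = callback_kwargs["latents"]
    n_nan = torch.isnan(latents).sum().item()
    n_inf = torch.isinf(latents).sum().item()
    if n_nan or n_inf:
        print(f"step {step} (t={timestep}): {n_nan} NaN / {n_inf} Inf values, dtype={latents.dtype}")
    return callback_kwargs

Recording the first step index at which NaNs appear helps distinguish a blow-up inside the denoising loop (e.g. around the transformer / transformer_2 expert handoff) from one introduced later during VAE decoding.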

Reproduction

See the script above.
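
To narrow down which optimization triggers the NaNs, the memory features can be toggled independently. A hypothetical ablation sketch using the same diffusers calls as the script above, assuming a freshly constructed pipe per run (these hooks are not trivially reversible):

# enable one optimization at a time and rerun to find the first
# combination that produces NaN latents
USE_GROUP_OFFLOAD = True   # leaf-level offload on both transformers
USE_FP8_STORAGE = False    # float8_e4m3fn layerwise weight storage

for transformer in (pipe.transformer, pipe.transformer_2):
    if USE_GROUP_OFFLOAD:
        transformer.enable_group_offload(onload_device=onload_device, offload_type='leaf_level',
                                         use_stream=True, record_stream=True,
                                         non_blocking=True, low_cpu_mem_usage=True)
    if USE_FP8_STORAGE:
        transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn,
                                             compute_dtype=torch.bfloat16)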

Logs

System Info

Ubuntu 22.04, CUDA 12.4.0, Python 3.11, torch 2.8.0, diffusers 0.35.2

Who can help?

@yiyixuxu @DN6
