From 1aafbbc4ffb918f92abf0c2f8bc987fb1cf1c402 Mon Sep 17 00:00:00 2001
From: ajrasane <131806219+ajrasane@users.noreply.github.com>
Date: Fri, 31 Oct 2025 01:15:02 +0000
Subject: [PATCH 1/4] Update benchmarking for diffusers

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
---
 .../diffusers/quantization/diffusion_trt.py   | 86 ++++++++++---------
 1 file changed, 44 insertions(+), 42 deletions(-)

diff --git a/examples/diffusers/quantization/diffusion_trt.py b/examples/diffusers/quantization/diffusion_trt.py
index ad71fd354..e809099d4 100644
--- a/examples/diffusers/quantization/diffusion_trt.py
+++ b/examples/diffusers/quantization/diffusion_trt.py
@@ -49,6 +49,7 @@
 }
 
 
+@torch.inference_mode()
 def generate_image(pipe, prompt, image_name):
     seed = 42
     image = pipe(
@@ -61,56 +62,52 @@ def generate_image(pipe, prompt, image_name):
     print(f"Image generated saved as {image_name}")
 
 
-def benchmark_model(
-    pipe, prompt, num_warmup=10, num_runs=50, num_inference_steps=20, model_dtype=torch.float16
+@torch.inference_mode()
+def benchmark_backbone_standalone(
+    pipe, num_warmup=10, num_benchmark=100, model_name="flux-dev",
 ):
-    """Benchmark the backbone model inference time."""
+    """Benchmark the backbone model directly without running the full pipeline."""
     backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
 
-    backbone_times = []
+    # Generate dummy inputs for the backbone
+    dummy_inputs, _, _ = generate_dummy_inputs_and_dynamic_axes_and_shapes(model_name, backbone)
+
+    # Extract the dict from the tuple and move to cuda
+    dummy_inputs_dict = {
+        k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in dummy_inputs[0].items()
+    }
+
+    # Warmup
+    print(f"Warming up: {num_warmup} iterations")
+    for _ in tqdm(range(num_warmup), desc="Warmup"):
+        _ = backbone(**dummy_inputs_dict)
+
+    # Benchmark
+    torch.cuda.synchronize()
     start_event = torch.cuda.Event(enable_timing=True)
     end_event = torch.cuda.Event(enable_timing=True)
 
-    def forward_pre_hook(_module, _input):
+    print(f"Benchmarking: {num_benchmark} iterations")
+    times = []
+    for _ in tqdm(range(num_benchmark), desc="Benchmark"):
         start_event.record()
-
-    def forward_hook(_module, _input, _output):
+        _ = backbone(**dummy_inputs_dict)
         end_event.record()
         torch.cuda.synchronize()
-        backbone_times.append(start_event.elapsed_time(end_event))
-
-    pre_handle = backbone.register_forward_pre_hook(forward_pre_hook)
-    post_handle = backbone.register_forward_hook(forward_hook)
-
-    try:
-        print(f"Starting warmup: {num_warmup} runs")
-        for _ in tqdm(range(num_warmup), desc="Warmup"):
-            with torch.amp.autocast("cuda", dtype=model_dtype):
-                _ = pipe(
-                    prompt,
-                    output_type="pil",
-                    num_inference_steps=num_inference_steps,
-                    generator=torch.Generator("cuda").manual_seed(42),
-                )
-
-        backbone_times.clear()
-
-        print(f"Starting benchmark: {num_runs} runs")
-        for _ in tqdm(range(num_runs), desc="Benchmark"):
-            with torch.amp.autocast("cuda", dtype=model_dtype):
-                _ = pipe(
-                    prompt,
-                    output_type="pil",
-                    num_inference_steps=num_inference_steps,
-                    generator=torch.Generator("cuda").manual_seed(42),
-                )
-    finally:
-        pre_handle.remove()
-        post_handle.remove()
-
-    total_backbone_time = sum(backbone_times)
-    avg_latency = total_backbone_time / (num_runs * num_inference_steps)
-    print(f"Inference latency of the torch backbone: {avg_latency:.2f} ms")
+        times.append(start_event.elapsed_time(end_event))
+
+    avg_latency = sum(times) / len(times)
+    times = sorted(times)
+    p50 = times[len(times) // 2]
+    p95 = times[int(len(times) * 0.95)]
+    p99 = times[int(len(times) * 0.99)]
+
+    print(f"\nBackbone-only inference latency:")
+    print(f"  Average: {avg_latency:.2f} ms")
+    print(f"  P50: {p50:.2f} ms")
+    print(f"  P95: {p95:.2f} ms")
+    print(f"  P99: {p99:.2f} ms")
+    return avg_latency
 
 
@@ -196,7 +193,12 @@ def main():
         pipe.to("cuda")
 
     if args.benchmark:
-        benchmark_model(pipe, args.prompt, model_dtype=model_dtype)
+        benchmark_backbone_standalone(
+            pipe,
+            num_warmup=10,
+            num_benchmark=100,
+            model_name=args.model,
+        )
 
     if not args.skip_image:
         generate_image(pipe, args.prompt, image_name)
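
Note on the measurement above: the new benchmark times only the backbone forward pass with CUDA events, so host-side pipeline and scheduler overhead is excluded from the reported latency. A minimal, self-contained sketch of the same event-timing pattern, using a stand-in module and made-up shapes rather than the script's real backbone and generated dummy inputs:

    import torch
    import torch.nn as nn

    def time_gpu_forward(module, example_inputs, num_warmup=10, num_benchmark=100):
        """Average CUDA time (ms) of module(**example_inputs), timed with CUDA events."""
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        with torch.inference_mode():
            for _ in range(num_warmup):  # warm up kernels and caches
                module(**example_inputs)
            torch.cuda.synchronize()
            times = []
            for _ in range(num_benchmark):
                start.record()
                module(**example_inputs)
                end.record()
                torch.cuda.synchronize()  # elapsed_time is only valid once both events complete
                times.append(start.elapsed_time(end))  # milliseconds
        return sum(times) / len(times)

    # Hypothetical usage with a stand-in module; the real script times pipe.transformer / pipe.unet.
    if torch.cuda.is_available():
        toy = nn.Linear(1024, 1024).cuda().half()
        x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
        print(time_gpu_forward(toy, {"input": x}))

The synchronize after end.record() is what makes elapsed_time meaningful; without it the reading could be taken before the queued kernels finish.
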
From 646458abc8d3d7b840dcb3b1be97ed7d4aa23edc Mon Sep 17 00:00:00 2001
From: ajrasane <131806219+ajrasane@users.noreply.github.com>
Date: Fri, 7 Nov 2025 20:04:51 +0000
Subject: [PATCH 2/4] Add cuda profiler

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
---
 examples/diffusers/quantization/diffusion_trt.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/examples/diffusers/quantization/diffusion_trt.py b/examples/diffusers/quantization/diffusion_trt.py
index e809099d4..508951565 100644
--- a/examples/diffusers/quantization/diffusion_trt.py
+++ b/examples/diffusers/quantization/diffusion_trt.py
@@ -64,7 +64,10 @@ def generate_image(pipe, prompt, image_name):
 
 @torch.inference_mode()
 def benchmark_backbone_standalone(
-    pipe, num_warmup=10, num_benchmark=100, model_name="flux-dev",
+    pipe,
+    num_warmup=10,
+    num_benchmark=100,
+    model_name="flux-dev",
 ):
     """Benchmark the backbone model directly without running the full pipeline."""
     backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
@@ -90,10 +93,12 @@ def benchmark_backbone_standalone(
     print(f"Benchmarking: {num_benchmark} iterations")
     times = []
     for _ in tqdm(range(num_benchmark), desc="Benchmark"):
+        torch.cuda.profiler.cudart().cudaProfilerStart()
         start_event.record()
         _ = backbone(**dummy_inputs_dict)
         end_event.record()
         torch.cuda.synchronize()
+        torch.cuda.profiler.cudart().cudaProfilerStop()
         times.append(start_event.elapsed_time(end_event))
 
     avg_latency = sum(times) / len(times)
@@ -102,7 +107,7 @@ def benchmark_backbone_standalone(
     p95 = times[int(len(times) * 0.95)]
     p99 = times[int(len(times) * 0.99)]
 
-    print(f"\nBackbone-only inference latency:")
+    print("\nBackbone-only inference latency:")
     print(f"  Average: {avg_latency:.2f} ms")
     print(f"  P50: {p50:.2f} ms")
     print(f"  P95: {p95:.2f} ms")

From 99ee1966c9f826f85dc030b524a9272f18c02f84 Mon Sep 17 00:00:00 2001
From: ajrasane <131806219+ajrasane@users.noreply.github.com>
Date: Mon, 10 Nov 2025 11:08:16 +0000
Subject: [PATCH 3/4] Replace percentile methods

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
---
 examples/diffusers/quantization/diffusion_trt.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/examples/diffusers/quantization/diffusion_trt.py b/examples/diffusers/quantization/diffusion_trt.py
index 508951565..97efae409 100644
--- a/examples/diffusers/quantization/diffusion_trt.py
+++ b/examples/diffusers/quantization/diffusion_trt.py
@@ -15,6 +15,7 @@
 
 import argparse
 
+import numpy as np
 import torch
 from onnx_utils.export import (
     generate_dummy_inputs_and_dynamic_axes_and_shapes,
@@ -102,10 +103,9 @@ def benchmark_backbone_standalone(
         times.append(start_event.elapsed_time(end_event))
 
     avg_latency = sum(times) / len(times)
-    times = sorted(times)
-    p50 = times[len(times) // 2]
-    p95 = times[int(len(times) * 0.95)]
-    p99 = times[int(len(times) * 0.99)]
+    p50 = np.percentile(times, 50)
+    p95 = np.percentile(times, 95)
+    p99 = np.percentile(times, 99)
 
     print("\nBackbone-only inference latency:")
     print(f"  Average: {avg_latency:.2f} ms")
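
For reference, np.percentile interpolates between neighboring samples, whereas the index arithmetic it replaces truncates to a single sorted element; on small or skewed samples the two can differ noticeably. A toy comparison with made-up latencies:

    import numpy as np

    times = [10.0, 11.0, 12.0, 13.0, 100.0]  # toy latencies in ms, not real measurements

    # Previous approach: pick one element of the sorted list (truncating index).
    sorted_times = sorted(times)
    p95_indexed = sorted_times[int(len(sorted_times) * 0.95)]  # -> 100.0

    # New approach: linear interpolation between the two surrounding samples.
    p95_interp = np.percentile(times, 95)  # -> 82.6

    print(f"indexed={p95_indexed:.1f} ms, interpolated={p95_interp:.1f} ms")
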
From e4453f73645eef65941765b9fbbeaed6ef7db151 Mon Sep 17 00:00:00 2001
From: ajrasane <131806219+ajrasane@users.noreply.github.com>
Date: Mon, 10 Nov 2025 11:32:49 +0000
Subject: [PATCH 4/4] Update the dynamic axes check

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
---
 modelopt/torch/_deploy/utils/torch_onnx.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py
index 0a16b66f5..cfebd0dc1 100644
--- a/modelopt/torch/_deploy/utils/torch_onnx.py
+++ b/modelopt/torch/_deploy/utils/torch_onnx.py
@@ -29,7 +29,6 @@
 import torch.nn as nn
 from onnx import ModelProto
 from onnxconverter_common import convert_float_to_float16
-from packaging.version import Version
 from torch.nn.parallel import DataParallel, DistributedDataParallel
 
 from modelopt.onnx.autocast.convert import convert_to_f16
@@ -443,7 +442,7 @@ def get_onnx_bytes_and_metadata(
     )
     with torch.inference_mode(), autocast, quantizer_context:
         additional_kwargs = {}
-        if not dynamo_export and Version(torch.__version__) >= Version("2.8"):
+        if not dynamo_export:
             additional_kwargs["dynamic_axes"] = dynamic_axes
         torch.onnx.export(
             model,
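
The last patch forwards dynamic_axes to torch.onnx.export unconditionally on the non-dynamo path instead of gating it on torch >= 2.8. For context, a minimal sketch of what that argument controls, with a placeholder model and axis names (not the ones generated for the diffusion backbones):

    import torch
    import torch.nn as nn

    model = nn.Linear(16, 4).eval()
    dummy = torch.randn(2, 16)

    # Mark the batch dimension of input and output as dynamic so the exported
    # ONNX graph accepts any batch size, not just the dummy input's batch of 2.
    torch.onnx.export(
        model,
        (dummy,),
        "linear.onnx",
        input_names=["x"],
        output_names=["y"],
        dynamic_axes={"x": {0: "batch"}, "y": {0: "batch"}},
    )
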