Commit e74a468

Update benchmarking for diffusers (#487)
## What does this PR do?

**Type of change:** Example update

**Overview:** Optimize the benchmarking function in the diffusers example.

```bash
python diffusion_trt.py --model flux-dev --benchmark --model-dtype BFloat16 --skip-image --torch
```

## Testing

```
Backbone-only inference latency (BFloat16):
  Average: 139.48 ms
  P50: 139.36 ms
  P95: 141.13 ms
  P99: 141.35 ms
```

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: No
- **Did you add or update any necessary documentation?**: No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: No

---------

Signed-off-by: ajrasane <[email protected]>
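The new benchmark reports average and percentile latencies measured with CUDA events around each backbone call. A minimal standalone sketch of that measurement pattern, for readers who want to reuse it outside this example (illustrative code, not the PR's; `time_module` is a hypothetical helper and a CUDA device is assumed):

```python
import numpy as np
import torch


@torch.inference_mode()
def time_module(module, inputs, num_warmup=10, num_benchmark=100):
    """Time a CUDA module with events and report percentile latencies."""
    # Warmup so allocator caches and kernels are primed before timing.
    for _ in range(num_warmup):
        _ = module(**inputs)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    times = []
    for _ in range(num_benchmark):
        start.record()
        _ = module(**inputs)
        end.record()
        torch.cuda.synchronize()  # elapsed_time() is only valid after sync
        times.append(start.elapsed_time(end))  # milliseconds

    print(f"Average: {sum(times) / len(times):.2f} ms")
    for p in (50, 95, 99):
        print(f"P{p}: {np.percentile(times, p):.2f} ms")
```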
1 parent 41f21bc commit e74a468

File tree

2 files changed: +50 -44 lines changed

examples/diffusers/quantization/diffusion_trt.py (49 additions, 42 deletions)

```diff
@@ -15,6 +15,7 @@
 
 import argparse
 
+import numpy as np
 import torch
 from onnx_utils.export import (
     generate_dummy_inputs_and_dynamic_axes_and_shapes,
@@ -49,6 +50,7 @@
 }
 
 
+@torch.inference_mode()
 def generate_image(pipe, prompt, image_name):
     seed = 42
     image = pipe(
@@ -61,56 +63,56 @@ def generate_image(pipe, prompt, image_name):
     print(f"Image generated saved as {image_name}")
 
 
-def benchmark_model(
-    pipe, prompt, num_warmup=10, num_runs=50, num_inference_steps=20, model_dtype=torch.float16
+@torch.inference_mode()
+def benchmark_backbone_standalone(
+    pipe,
+    num_warmup=10,
+    num_benchmark=100,
+    model_name="flux-dev",
 ):
-    """Benchmark the backbone model inference time."""
+    """Benchmark the backbone model directly without running the full pipeline."""
     backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
 
-    backbone_times = []
+    # Generate dummy inputs for the backbone
+    dummy_inputs, _, _ = generate_dummy_inputs_and_dynamic_axes_and_shapes(model_name, backbone)
+
+    # Extract the dict from the tuple and move to cuda
+    dummy_inputs_dict = {
+        k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in dummy_inputs[0].items()
+    }
+
+    # Warmup
+    print(f"Warming up: {num_warmup} iterations")
+    for _ in tqdm(range(num_warmup), desc="Warmup"):
+        _ = backbone(**dummy_inputs_dict)
+
+    # Benchmark
+    torch.cuda.synchronize()
     start_event = torch.cuda.Event(enable_timing=True)
     end_event = torch.cuda.Event(enable_timing=True)
 
-    def forward_pre_hook(_module, _input):
+    print(f"Benchmarking: {num_benchmark} iterations")
+    times = []
+    for _ in tqdm(range(num_benchmark), desc="Benchmark"):
+        torch.cuda.profiler.cudart().cudaProfilerStart()
         start_event.record()
-
-    def forward_hook(_module, _input, _output):
+        _ = backbone(**dummy_inputs_dict)
         end_event.record()
         torch.cuda.synchronize()
-        backbone_times.append(start_event.elapsed_time(end_event))
-
-    pre_handle = backbone.register_forward_pre_hook(forward_pre_hook)
-    post_handle = backbone.register_forward_hook(forward_hook)
-
-    try:
-        print(f"Starting warmup: {num_warmup} runs")
-        for _ in tqdm(range(num_warmup), desc="Warmup"):
-            with torch.amp.autocast("cuda", dtype=model_dtype):
-                _ = pipe(
-                    prompt,
-                    output_type="pil",
-                    num_inference_steps=num_inference_steps,
-                    generator=torch.Generator("cuda").manual_seed(42),
-                )
-
-        backbone_times.clear()
-
-        print(f"Starting benchmark: {num_runs} runs")
-        for _ in tqdm(range(num_runs), desc="Benchmark"):
-            with torch.amp.autocast("cuda", dtype=model_dtype):
-                _ = pipe(
-                    prompt,
-                    output_type="pil",
-                    num_inference_steps=num_inference_steps,
-                    generator=torch.Generator("cuda").manual_seed(42),
-                )
-    finally:
-        pre_handle.remove()
-        post_handle.remove()
-
-    total_backbone_time = sum(backbone_times)
-    avg_latency = total_backbone_time / (num_runs * num_inference_steps)
-    print(f"Inference latency of the torch backbone: {avg_latency:.2f} ms")
+        torch.cuda.profiler.cudart().cudaProfilerStop()
+        times.append(start_event.elapsed_time(end_event))
+
+    avg_latency = sum(times) / len(times)
+    p50 = np.percentile(times, 50)
+    p95 = np.percentile(times, 95)
+    p99 = np.percentile(times, 99)
+
+    print("\nBackbone-only inference latency:")
+    print(f"  Average: {avg_latency:.2f} ms")
+    print(f"  P50: {p50:.2f} ms")
+    print(f"  P95: {p95:.2f} ms")
+    print(f"  P99: {p99:.2f} ms")
     return avg_latency
 
 
@@ -196,7 +198,12 @@ def main():
     pipe.to("cuda")
 
     if args.benchmark:
-        benchmark_model(pipe, args.prompt, model_dtype=model_dtype)
+        benchmark_backbone_standalone(
+            pipe,
+            num_warmup=10,
+            num_benchmark=100,
+            model_name=args.model,
+        )
 
     if not args.skip_image:
         generate_image(pipe, args.prompt, image_name)
```
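One detail of the benchmark loop above: `torch.cuda.profiler.cudart().cudaProfilerStart()` / `cudaProfilerStop()` bracket each timed iteration, which lets an external profiler capture only the benchmark region. A minimal sketch of the same pattern in isolation (the module and shapes are illustrative; the `nsys` flags shown are, to the best of my knowledge, the standard way to honor these markers):

```python
import torch

# Launch with, e.g.:
#   nsys profile --capture-range=cudaProfilerApi --capture-range-end=stop python bench.py
# so profiling starts/stops at the markers below instead of spanning the whole process.
model = torch.nn.Linear(4096, 4096).cuda().half()
x = torch.randn(64, 4096, device="cuda", dtype=torch.half)

torch.cuda.profiler.cudart().cudaProfilerStart()
with torch.inference_mode():
    _ = model(x)
torch.cuda.synchronize()  # make sure the kernel lands inside the capture range
torch.cuda.profiler.cudart().cudaProfilerStop()
```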

modelopt/torch/_deploy/utils/torch_onnx.py (1 addition, 2 deletions)

```diff
@@ -29,7 +29,6 @@
 import torch.nn as nn
 from onnx import ModelProto
 from onnxconverter_common import convert_float_to_float16
-from packaging.version import Version
 from torch.nn.parallel import DataParallel, DistributedDataParallel
 
 from modelopt.onnx.autocast.convert import convert_to_f16
@@ -443,7 +442,7 @@ def get_onnx_bytes_and_metadata(
     )
     with torch.inference_mode(), autocast, quantizer_context:
         additional_kwargs = {}
-        if not dynamo_export and Version(torch.__version__) >= Version("2.8"):
+        if not dynamo_export:
             additional_kwargs["dynamic_axes"] = dynamic_axes
         torch.onnx.export(
             model,
```
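For context on this change: `dynamic_axes` tells the TorchScript-based `torch.onnx.export` path which tensor dimensions may vary at runtime, and it is now passed whenever `dynamo_export` is disabled, without the previous torch >= 2.8 version gate. A minimal sketch of the underlying API (model, file name, and axis labels are illustrative):

```python
import torch

model = torch.nn.Linear(16, 8)
dummy = torch.randn(1, 16)

# Mark dim 0 of the input and output as dynamic so the exported
# ONNX graph accepts any batch size, not just the dummy input's.
torch.onnx.export(
    model,
    (dummy,),
    "linear.onnx",
    input_names=["x"],
    output_names=["y"],
    dynamic_axes={"x": {0: "batch"}, "y": {0: "batch"}},
)
```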
