From d78751e666a0164c767b2dab371e8294b04d7339 Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 18 Aug 2025 23:35:01 -0700 Subject: [PATCH 1/5] Changes to TRT-LLM download tool for multigpu distributed case --- .../tensor_parallel_initialize_dist.py | 81 ----- .../tensor_parallel_rotary_embedding.py | 12 +- .../tensor_parallel_simple_example.py | 15 +- .../conversion/custom_ops_converters.py | 1 + .../dynamo/distributed/__init__.py | 0 py/torch_tensorrt/dynamo/distributed/utils.py | 341 ++++++++++++++++++ .../dynamo/distributed/distributed_utils.py | 44 --- 7 files changed, 362 insertions(+), 132 deletions(-) delete mode 100644 examples/distributed_inference/tensor_parallel_initialize_dist.py create mode 100644 py/torch_tensorrt/dynamo/distributed/__init__.py create mode 100644 py/torch_tensorrt/dynamo/distributed/utils.py diff --git a/examples/distributed_inference/tensor_parallel_initialize_dist.py b/examples/distributed_inference/tensor_parallel_initialize_dist.py deleted file mode 100644 index 98d3ca18e9..0000000000 --- a/examples/distributed_inference/tensor_parallel_initialize_dist.py +++ /dev/null @@ -1,81 +0,0 @@ -""" -.. _tensor_parallel_initialize_dist: -Tensor Parallel Initialize Distributed Environment -================================================== - -This module provides functions to initialize and clean up the distributed environment for tensor parallel distributed inference. -""" - -import logging -import os -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union - -import numpy as np -import tensorrt as trt -import torch -import torch.distributed as dist -from torch.distributed._tensor.device_mesh import init_device_mesh - - -def find_repo_root(max_depth=10): - dir_path = os.path.dirname(os.path.realpath(__file__)) - for i in range(max_depth): - files = os.listdir(dir_path) - if "MODULE.bazel" in files: - return dir_path - else: - dir_path = os.path.dirname(dir_path) - - raise RuntimeError("Could not find repo root") - - -def initialize_logger(rank, logger_file_name): - logger = logging.getLogger() - logger.setLevel(logging.INFO) - fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w") - fh.setLevel(logging.INFO) - logger.addHandler(fh) - return logger - - -# This is required for env initialization since we use mpirun -def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500): - local_rank = int( - os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count()) - ) - world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size)) - - # Set up environment variable to run with mpirun - os.environ["RANK"] = str(local_rank) - os.environ["WORLD_SIZE"] = str(world_size) - os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = str(port) - os.environ["TRTLLM_PLUGINS_PATH"] = ( - find_repo_root() + "/lib/libnvinfer_plugin_tensorrt_llm.so" - ) - - # Necessary to assign a device to each rank. - torch.cuda.set_device(local_rank) - - # We use nccl backend - dist.init_process_group("nccl") - - # set a manual seed for reproducibility - torch.manual_seed(1111) - - device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,)) - rank = device_mesh.get_rank() - assert rank == local_rank - logger = initialize_logger(rank, logger_file_name) - device_id = ( - rank % torch.cuda.device_count() - ) # Ensure each rank gets a unique device - torch.cuda.set_device(device_id) - - return device_mesh, world_size, rank, logger - - -def cleanup_distributed_env(): - """Clean up distributed process group to prevent resource leaks.""" - if dist.is_initialized(): - dist.destroy_process_group() diff --git a/examples/distributed_inference/tensor_parallel_rotary_embedding.py b/examples/distributed_inference/tensor_parallel_rotary_embedding.py index da3f3fd8fd..d51f9a5787 100644 --- a/examples/distributed_inference/tensor_parallel_rotary_embedding.py +++ b/examples/distributed_inference/tensor_parallel_rotary_embedding.py @@ -16,15 +16,19 @@ import torch import torch_tensorrt from rotary_embedding import RotaryAttention, parallel_rotary_block -from tensor_parallel_initialize_dist import ( +from torch.distributed import dist +from torch_tensorrt.dynamo.distributed.utils import ( cleanup_distributed_env, + get_tensor_parallel_device_mesh, initialize_distributed_env, + initialize_logger, ) -device_mesh, _world_size, _rank, logger = initialize_distributed_env( - "./tensor_parallel_rotary_embedding" -) +if not dist.is_initialized(): + initialize_distributed_env() +device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh() +logger = initialize_logger(_rank, "tensor_parallel_simple_example") """ This example covers the rotary embedding in Llama3 model and is derived from https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning diff --git a/examples/distributed_inference/tensor_parallel_simple_example.py b/examples/distributed_inference/tensor_parallel_simple_example.py index c5688c6e5b..8412cb7fc6 100755 --- a/examples/distributed_inference/tensor_parallel_simple_example.py +++ b/examples/distributed_inference/tensor_parallel_simple_example.py @@ -36,11 +36,20 @@ RowwiseParallel, parallelize_module, ) - -device_mesh, _world_size, _rank, logger = initialize_distributed_env( - "./tensor_parallel_simple_example" +from torch_tensorrt.dynamo.distributed.utils import ( + cleanup_distributed_env, + get_tensor_parallel_device_mesh, + initialize_distributed_env, + initialize_logger, ) +if not dist.is_initialized(): + initialize_distributed_env() + +device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh() +logger = initialize_logger(_rank, "tensor_parallel_simple_example") + + """ This example takes some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py """ diff --git a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py index db14e3528b..b1c13c105f 100644 --- a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py @@ -12,6 +12,7 @@ from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( dynamo_tensorrt_converter, ) +from torch_tensorrt.dynamo.distributed.utils import load_tensorrt_llm_for_nccl from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import ( tensorrt_fused_nccl_all_gather_op, tensorrt_fused_nccl_reduce_scatter_op, diff --git a/py/torch_tensorrt/dynamo/distributed/__init__.py b/py/torch_tensorrt/dynamo/distributed/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/py/torch_tensorrt/dynamo/distributed/utils.py b/py/torch_tensorrt/dynamo/distributed/utils.py new file mode 100644 index 0000000000..099daf7ef7 --- /dev/null +++ b/py/torch_tensorrt/dynamo/distributed/utils.py @@ -0,0 +1,341 @@ +import ctypes +import getpass +import logging +import os +import platform +import tempfile +import urllib.request +from pathlib import Path +from typing import Optional + +import torch +import torch.distributed as dist +from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh +from torch_tensorrt._version import __tensorrt_llm_version__ + +_WHL_CPYTHON_VERSION = "cp310" + +logger = logging.getLogger(__name__) + + +def initialize_distributed_env( + rank: int = 0, world_size: int = 1, port: int = 29500 +) -> None: + local_rank = int( + os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count()) + ) + world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size)) + + # Set up environment variable to run with mpirun + os.environ["RANK"] = str(local_rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + + # Necessary to assign a device to each rank. + torch.cuda.set_device(local_rank) + + # We use nccl backend + dist.init_process_group("nccl") + + # set a manual seed for reproducibility + torch.manual_seed(1111) + + +def check_tensor_parallel_device_number(world_size: int) -> None: + if world_size % 2 != 0: + raise ValueError( + f"TP examples require even number of GPUs, but got {world_size} gpus" + ) + + +def get_tensor_parallel_device_mesh( + rank: int = 0, world_size: int = 1 +) -> tuple[DeviceMesh, int, int]: + local_rank = int( + os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count()) + ) + world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size)) + device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,)) + rank = device_mesh.get_rank() + assert rank == local_rank + device_id = ( + rank % torch.cuda.device_count() + ) # Ensure each rank gets a unique device + torch.cuda.set_device(device_id) + + return device_mesh, world_size, rank + + +def initialize_logger(rank: int, logger_file_name: str) -> logging.Logger: + logger = logging.getLogger() + logger.setLevel(logging.INFO) + fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w") + fh.setLevel(logging.INFO) + logger.addHandler(fh) + return logger + + +def cleanup_distributed_env() -> None: + """Clean up distributed process group to prevent resource leaks.""" + if dist.is_initialized(): + dist.destroy_process_group() + + +def is_platform_supported_for_trtllm() -> bool: + """ + Checks if the current platform supports TensorRT-LLM plugins for the NCCL backend. + + Returns: + bool: True if supported, False otherwise. + + Unsupported: + - Windows platforms + - Jetson/Orin/Xavier (aarch64 architecture + 'tegra' in platform release) + - CUDA 13 not supported + """ + system = platform.system().lower() + machine = platform.machine().lower() + release = platform.release().lower() + + if "windows" in system: + logger.info( + "TensorRT-LLM plugins for NCCL backend are not supported on Windows." + ) + return False + + if machine == "aarch64" and "tegra" in release: + logger.info( + "TensorRT-LLM plugins for NCCL backend are not supported on Jetson/Orin/Xavier (Tegra) devices." + ) + return False + + try: + cuda_version = torch.version.cuda # e.g., "12.4" or "13.0" + if cuda_version is None: + logger.warning("No CUDA runtime detected — TRT-LLM plugins unavailable.") + return False + + major, minor = map(int, cuda_version.split(".")) + if major != 12: + logger.warning("CUDA 13 is not supported for TRT-LLM plugins.") + return False + + return True + + except Exception as e: + logger.warning(f"Failed to detect CUDA version: {e}") + return False + + + return True + + +def _cache_root() -> Path: + username = getpass.getuser() + return Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}" + + +def _extracted_dir_trtllm(platform_system: str, platform_machine: str) -> Path: + return ( + _cache_root() + / "trtllm" + / f"{__tensorrt_llm_version__}_{platform_system}_{platform_machine}" + ) + + +def extract_wheel_file(wheel_path: Path, extract_dir: Path) -> None: + from torch.distributed import barrier, get_rank, is_initialized + + if not is_initialized(): + # Single process case, just unzip + is_master = True + else: + is_master = get_rank() == 0 # only rank 0 does the unzip + + if is_master: + try: + import zipfile + except ImportError as e: + raise ImportError( + "zipfile module is required but not found. Please install zipfile" + ) + try: + with zipfile.ZipFile(wheel_path) as zip_ref: + zip_ref.extractall(extract_dir) + logger.debug(f"Extracted wheel to {extract_dir}") + + except FileNotFoundError as e: + # This should capture the errors in the download failure above + logger.error(f"Wheel file not found at {wheel_path}: {e}") + raise RuntimeError( + f"Failed to find downloaded wheel file at {wheel_path}" + ) from e + except zipfile.BadZipFile as e: + logger.error(f"Invalid or corrupted wheel file: {e}") + raise RuntimeError( + "Downloaded wheel file is corrupted or not a valid zip archive" + ) from e + except Exception as e: + logger.error(f"Unexpected error while extracting wheel: {e}") + raise RuntimeError( + "Unexpected error during extraction of TensorRT-LLM wheel" + ) from e + + # Make sure others wait until unzip is done + if is_initialized(): + barrier() + + +def download_and_get_plugin_lib_path() -> Optional[str]: + """ + Returns the path to the TensorRT‑LLM shared library, downloading and extracting if necessary. + + Args: + platform (str): Platform identifier (e.g., 'linux_x86_64') + + Returns: + Optional[str]: Path to shared library or None if operation fails. + """ + platform_system = platform.system().lower() + platform_machine = platform.machine().lower() + wheel_filename = ( + f"tensorrt_llm-{__tensorrt_llm_version__}-{_WHL_CPYTHON_VERSION}-" + f"{_WHL_CPYTHON_VERSION}-{platform_system}_{platform_machine}.whl" + ) + wheel_path = _cache_root() / wheel_filename + extract_dir = _extracted_dir_trtllm(platform_system, platform_machine) + # else will never be met though + lib_filename = ( + "libnvinfer_plugin_tensorrt_llm.so" + if "linux" in platform_system + else "libnvinfer_plugin_tensorrt_llm.dll" + ) + # eg: /tmp/torch_tensorrt_/trtllm/0.17.0.post1_linux_x86_64/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so + plugin_lib_path = extract_dir / "tensorrt_llm" / "libs" / lib_filename + + if plugin_lib_path.exists(): + return str(plugin_lib_path) + + wheel_path.parent.mkdir(parents=True, exist_ok=True) + extract_dir.mkdir(parents=True, exist_ok=True) + + if not wheel_path.exists(): + base_url = "https://pypi.nvidia.com/tensorrt-llm/" + download_url = base_url + wheel_filename + try: + logger.debug(f"Downloading {download_url} ...") + urllib.request.urlretrieve(download_url, wheel_path) + logger.debug("Download succeeded and TRT-LLM wheel is now present") + except urllib.error.HTTPError as e: + logger.error( + f"HTTP error {e.code} when trying to download {download_url}: {e.reason}" + ) + except urllib.error.URLError as e: + logger.error( + f"URL error when trying to download {download_url}: {e.reason}" + ) + except OSError as e: + logger.error(f"Local file write error: {e}") + + extract_wheel_file(wheel_path, extract_dir) + + try: + wheel_path.unlink(missing_ok=True) + logger.debug(f"Deleted wheel file: {wheel_path}") + except Exception as e: + logger.warning(f"Could not delete wheel file {wheel_path}: {e}") + if not plugin_lib_path.exists(): + logger.error( + f"Plugin library not found at expected location: {plugin_lib_path}" + ) + return None + + return str(plugin_lib_path) + + +def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool: + """ + Loads and initializes the TensorRT-LLM plugin from the given shared library path. + + Args: + plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library. + + Returns: + bool: True if successful, False otherwise. + """ + try: + handle = ctypes.CDLL(plugin_lib_path) + logger.info(f"Successfully loaded plugin library: {plugin_lib_path}") + except OSError as e_os_error: + if "libmpi" in str(e_os_error): + logger.warning( + f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}, got error {e_os_error} (hint: libmpi.so is a necessary dependency; ensure that OpenMPI or MPICH is installed on your system)", + exc_info=e_os_error, + ) + else: + logger.warning( + f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. " + f"Ensure the path is correct and the library is compatible.", + exc_info=e_os_error, + ) + return False + + try: + handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p] + handle.initTrtLlmPlugins.restype = ctypes.c_bool + except AttributeError as e_plugin_unavailable: + logger.warning( + "Unable to initialize the TensorRT-LLM plugin library", + exc_info=e_plugin_unavailable, + ) + return False + + try: + if handle.initTrtLlmPlugins(None, b"tensorrt_llm"): + logger.info("TensorRT-LLM plugin successfully initialized") + return True + else: + logger.warning("TensorRT-LLM plugin library failed in initialization") + return False + except Exception as e_initialization_error: + logger.warning( + "Exception occurred during TensorRT-LLM plugin library initialization", + exc_info=e_initialization_error, + ) + return False + return False + + +def load_tensorrt_llm_for_nccl() -> bool: + """ + Attempts to load the TensorRT-LLM plugin and initialize it. + Either the env variable TRTLLM_PLUGINS_PATH can specify the path + Or the user can specify USE_TRTLLM_PLUGINS as either of (1, true, yes, on) to download the TRT-LLM distribution and load it + + Returns: + bool: True if the plugin was successfully loaded and initialized, False otherwise. + """ + if not is_platform_supported_for_trtllm(): + return False + plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH") + + if plugin_lib_path: + return load_and_initialize_trtllm_plugin(plugin_lib_path) + else: + # this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user + use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in ( + "1", + "true", + "yes", + "on", + ) + if not use_trtllm_plugin: + logger.warning( + "Neither TRTLLM_PLUGIN_PATH is set nor is it directed to download the shared library. Please set either of the two to use TRT-LLM libraries in torchTRT" + ) + return False + + plugin_lib_path = download_and_get_plugin_lib_path() + return load_and_initialize_trtllm_plugin(plugin_lib_path) # type: ignore[arg-type] + return False diff --git a/tests/py/dynamo/distributed/distributed_utils.py b/tests/py/dynamo/distributed/distributed_utils.py index bc058aaaec..a2661c22d7 100644 --- a/tests/py/dynamo/distributed/distributed_utils.py +++ b/tests/py/dynamo/distributed/distributed_utils.py @@ -13,47 +13,3 @@ def set_environment_variables_pytest(): os.environ["RANK"] = str(0) os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = str(29500) - - -def initialize_logger(rank, logger_file_name): - logger = logging.getLogger() - logger.setLevel(logging.INFO) - fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w") - fh.setLevel(logging.INFO) - logger.addHandler(fh) - return logger - - -# This is required for env initialization since we use mpirun -def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500): - local_rank = int( - os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count()) - ) - world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size)) - - # Set up environment variable to run with mpirun - os.environ["RANK"] = str(local_rank) - os.environ["WORLD_SIZE"] = str(world_size) - os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = str(port) - os.environ["TRTLLM_PLUGINS_PATH"] = "./tmp/lib/libnvinfer_plugin_tensorrt_llm.so" - - # Necessary to assign a device to each rank. - torch.cuda.set_device(local_rank) - - # We use nccl backend - dist.init_process_group("nccl") - - # set a manual seed for reproducibility - torch.manual_seed(1111) - - device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,)) - rank = device_mesh.get_rank() - assert rank == local_rank - logger = initialize_logger(rank, logger_file_name) - device_id = ( - rank % torch.cuda.device_count() - ) # Ensure each rank gets a unique device - torch.cuda.set_device(device_id) - - return device_mesh, world_size, rank, logger From 56b80db08ad565a013f1805a9359a774f1c66703 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 25 Sep 2025 12:14:31 -0700 Subject: [PATCH 2/5] Distributed utils package, separating out env for single GPU and multiGPU --- .../tensor_parallel_initialize_dist.py | 55 +++++++++++++++++++ .../tensor_parallel_rotary_embedding.py | 18 +++--- .../tensor_parallel_simple_example.py | 10 ++-- .../dynamo/distributed/__init__.py | 1 + py/torch_tensorrt/dynamo/distributed/utils.py | 32 ----------- setup.py | 1 + .../dynamo/distributed/distributed_utils.py | 32 ++++++++++- tests/py/dynamo/distributed/test_nccl_ops.py | 25 +++++++-- 8 files changed, 121 insertions(+), 53 deletions(-) create mode 100644 examples/distributed_inference/tensor_parallel_initialize_dist.py diff --git a/examples/distributed_inference/tensor_parallel_initialize_dist.py b/examples/distributed_inference/tensor_parallel_initialize_dist.py new file mode 100644 index 0000000000..068316659e --- /dev/null +++ b/examples/distributed_inference/tensor_parallel_initialize_dist.py @@ -0,0 +1,55 @@ +""" +.. _tensor_parallel_initialize_dist: +Tensor Parallel Initialize Distributed Environment +================================================== + +This module provides functions to initialize and clean up the distributed environment for tensor parallel distributed inference. +""" + +import logging +import os +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +import numpy as np +import tensorrt as trt +import torch +import torch.distributed as dist +from torch.distributed._tensor.device_mesh import init_device_mesh + + +def initialize_distributed_env(rank=0, world_size=1, port=29500): + local_rank = int( + os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count()) + ) + world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size)) + + # Set up environment variable to run with mpirun + os.environ["RANK"] = str(local_rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + + # Necessary to assign a device to each rank. + torch.cuda.set_device(local_rank) + + # We use nccl backend + dist.init_process_group("nccl") + + # set a manual seed for reproducibility + torch.manual_seed(1111) + + device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,)) + rank = device_mesh.get_rank() + assert rank == local_rank + device_id = ( + rank % torch.cuda.device_count() + ) # Ensure each rank gets a unique device + torch.cuda.set_device(device_id) + + return device_mesh, world_size, rank + + +def cleanup_distributed_env(): + """Clean up distributed process group to prevent resource leaks.""" + if dist.is_initialized(): + dist.destroy_process_group() diff --git a/examples/distributed_inference/tensor_parallel_rotary_embedding.py b/examples/distributed_inference/tensor_parallel_rotary_embedding.py index d51f9a5787..2f3de7d4e2 100644 --- a/examples/distributed_inference/tensor_parallel_rotary_embedding.py +++ b/examples/distributed_inference/tensor_parallel_rotary_embedding.py @@ -14,21 +14,25 @@ import time import torch -import torch_tensorrt -from rotary_embedding import RotaryAttention, parallel_rotary_block -from torch.distributed import dist -from torch_tensorrt.dynamo.distributed.utils import ( +import torch.distributed as dist +from tensor_parallel_initialize_dist import ( cleanup_distributed_env, - get_tensor_parallel_device_mesh, initialize_distributed_env, - initialize_logger, ) if not dist.is_initialized(): initialize_distributed_env() +import torch_tensorrt +from torch_tensorrt.dynamo.distributed.utils import ( + get_tensor_parallel_device_mesh, + initialize_logger, +) + device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh() -logger = initialize_logger(_rank, "tensor_parallel_simple_example") +logger = initialize_logger(_rank, "tensor_parallel_rotary_embedding") + +from rotary_embedding import RotaryAttention, parallel_rotary_block """ This example covers the rotary embedding in Llama3 model and is derived from https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning diff --git a/examples/distributed_inference/tensor_parallel_simple_example.py b/examples/distributed_inference/tensor_parallel_simple_example.py index 8412cb7fc6..ca0ecaf9a1 100755 --- a/examples/distributed_inference/tensor_parallel_simple_example.py +++ b/examples/distributed_inference/tensor_parallel_simple_example.py @@ -25,11 +25,14 @@ import torch import torch.distributed as dist import torch.nn as nn -import torch_tensorrt from tensor_parallel_initialize_dist import ( cleanup_distributed_env, initialize_distributed_env, ) + +if not dist.is_initialized(): + initialize_distributed_env() +import torch_tensorrt from torch.distributed._tensor import Shard from torch.distributed.tensor.parallel import ( ColwiseParallel, @@ -37,15 +40,10 @@ parallelize_module, ) from torch_tensorrt.dynamo.distributed.utils import ( - cleanup_distributed_env, get_tensor_parallel_device_mesh, - initialize_distributed_env, initialize_logger, ) -if not dist.is_initialized(): - initialize_distributed_env() - device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh() logger = initialize_logger(_rank, "tensor_parallel_simple_example") diff --git a/py/torch_tensorrt/dynamo/distributed/__init__.py b/py/torch_tensorrt/dynamo/distributed/__init__.py index e69de29bb2..8b13789179 100644 --- a/py/torch_tensorrt/dynamo/distributed/__init__.py +++ b/py/torch_tensorrt/dynamo/distributed/__init__.py @@ -0,0 +1 @@ + diff --git a/py/torch_tensorrt/dynamo/distributed/utils.py b/py/torch_tensorrt/dynamo/distributed/utils.py index 099daf7ef7..ad217a09af 100644 --- a/py/torch_tensorrt/dynamo/distributed/utils.py +++ b/py/torch_tensorrt/dynamo/distributed/utils.py @@ -9,7 +9,6 @@ from typing import Optional import torch -import torch.distributed as dist from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh from torch_tensorrt._version import __tensorrt_llm_version__ @@ -18,30 +17,6 @@ logger = logging.getLogger(__name__) -def initialize_distributed_env( - rank: int = 0, world_size: int = 1, port: int = 29500 -) -> None: - local_rank = int( - os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count()) - ) - world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size)) - - # Set up environment variable to run with mpirun - os.environ["RANK"] = str(local_rank) - os.environ["WORLD_SIZE"] = str(world_size) - os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = str(port) - - # Necessary to assign a device to each rank. - torch.cuda.set_device(local_rank) - - # We use nccl backend - dist.init_process_group("nccl") - - # set a manual seed for reproducibility - torch.manual_seed(1111) - - def check_tensor_parallel_device_number(world_size: int) -> None: if world_size % 2 != 0: raise ValueError( @@ -76,12 +51,6 @@ def initialize_logger(rank: int, logger_file_name: str) -> logging.Logger: return logger -def cleanup_distributed_env() -> None: - """Clean up distributed process group to prevent resource leaks.""" - if dist.is_initialized(): - dist.destroy_process_group() - - def is_platform_supported_for_trtllm() -> bool: """ Checks if the current platform supports TensorRT-LLM plugins for the NCCL backend. @@ -127,7 +96,6 @@ def is_platform_supported_for_trtllm() -> bool: logger.warning(f"Failed to detect CUDA version: {e}") return False - return True diff --git a/setup.py b/setup.py index d487530626..34cb95ffd9 100644 --- a/setup.py +++ b/setup.py @@ -450,6 +450,7 @@ def run(self): "torch_tensorrt.dynamo.conversion.impl.unary", "torch_tensorrt.dynamo.conversion.plugins", "torch_tensorrt.dynamo.debug", + "torch_tensorrt.dynamo.distributed", "torch_tensorrt.dynamo.lowering", "torch_tensorrt.dynamo.lowering.passes", "torch_tensorrt.dynamo.partitioning", diff --git a/tests/py/dynamo/distributed/distributed_utils.py b/tests/py/dynamo/distributed/distributed_utils.py index a2661c22d7..b13a07d308 100644 --- a/tests/py/dynamo/distributed/distributed_utils.py +++ b/tests/py/dynamo/distributed/distributed_utils.py @@ -1,5 +1,6 @@ import logging import os +import random import numpy as np import tensorrt as trt @@ -8,8 +9,35 @@ from torch.distributed._tensor.device_mesh import init_device_mesh -def set_environment_variables_pytest(): +def set_environment_variables_pytest_single_process(): + port = 29500 + random.randint(1, 1000) os.environ["WORLD_SIZE"] = str(1) os.environ["RANK"] = str(0) os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = str(29500) + os.environ["MASTER_PORT"] = str(port) + + +def set_environment_variables_pytest_multi_process( + rank: int = 0, world_size: int = 1 +) -> None: + port = 29500 + random.randint(1, 1000) + # these variables are set by mpirun -n 2 + local_rank = int( + os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count()) + ) + world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size)) + + # Set up environment variable to run with mpirun + os.environ["RANK"] = str(local_rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + + # Necessary to assign a device to each rank. + torch.cuda.set_device(local_rank) + + # We use nccl backend + dist.init_process_group("nccl") + + # set a manual seed for reproducibility + torch.manual_seed(1111) diff --git a/tests/py/dynamo/distributed/test_nccl_ops.py b/tests/py/dynamo/distributed/test_nccl_ops.py index eafe16d455..c6e803ca1c 100644 --- a/tests/py/dynamo/distributed/test_nccl_ops.py +++ b/tests/py/dynamo/distributed/test_nccl_ops.py @@ -5,11 +5,26 @@ import torch.distributed as dist import torch.nn as nn from conversion.harness import DispatchTestCase -from distributed_utils import set_environment_variables_pytest + +# The distributed env initialization has to be before torchTRT import since it uses barrier +from distributed_utils import ( + set_environment_variables_pytest_multi_process, + set_environment_variables_pytest_single_process, +) from parameterized import parameterized from torch.testing._internal.common_utils import run_tests from torch_tensorrt._utils import is_platform_supported_for_trtllm +if "OMPI_COMM_WORLD_SIZE" in os.environ: + set_environment_variables_pytest_multi_process() +else: + set_environment_variables_pytest_single_process() + +if not dist.is_initialized(): + dist.init_process_group( + backend="nccl", + init_method="env://", + ) class DistributedGatherModel(nn.Module): def __init__(self, input_dim, world_size, group_name): @@ -48,11 +63,9 @@ class TestNcclOpsConverter(DispatchTestCase): ) @classmethod def setUpClass(cls): - set_environment_variables_pytest() - cls.world_size = 1 - if not dist.is_initialized(): - dist.init_process_group(backend="nccl") - cls.group = dist.new_group(ranks=[0]) + cls.world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1)) + cls.rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0)) + cls.group = dist.new_group(ranks=list(range(cls.world_size))) cls.group_name = cls.group.group_name @classmethod From c7bf85283e0a3b12b7247d0262478c843384f5b4 Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 13 Oct 2025 11:41:51 -0700 Subject: [PATCH 3/5] changes to account for the base branch change --- .../tensor_parallel_initialize_dist.py | 1 + .../tensor_parallel_rotary_embedding.py | 5 +- .../tensor_parallel_simple_example.py | 4 +- .../conversion/custom_ops_converters.py | 1 - py/torch_tensorrt/dynamo/distributed/utils.py | 270 +----------------- .../dynamo/distributed/distributed_utils.py | 2 + tests/py/dynamo/distributed/test_nccl_ops.py | 3 +- 7 files changed, 11 insertions(+), 275 deletions(-) diff --git a/examples/distributed_inference/tensor_parallel_initialize_dist.py b/examples/distributed_inference/tensor_parallel_initialize_dist.py index 068316659e..9aa715ae35 100644 --- a/examples/distributed_inference/tensor_parallel_initialize_dist.py +++ b/examples/distributed_inference/tensor_parallel_initialize_dist.py @@ -17,6 +17,7 @@ from torch.distributed._tensor.device_mesh import init_device_mesh +# this is kept at the application level, when mpirun is used to run the application def initialize_distributed_env(rank=0, world_size=1, port=29500): local_rank = int( os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count()) diff --git a/examples/distributed_inference/tensor_parallel_rotary_embedding.py b/examples/distributed_inference/tensor_parallel_rotary_embedding.py index 2f3de7d4e2..7b0c9ef2e5 100644 --- a/examples/distributed_inference/tensor_parallel_rotary_embedding.py +++ b/examples/distributed_inference/tensor_parallel_rotary_embedding.py @@ -26,17 +26,18 @@ import torch_tensorrt from torch_tensorrt.dynamo.distributed.utils import ( get_tensor_parallel_device_mesh, - initialize_logger, + initialize_distributed_logger, ) device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh() -logger = initialize_logger(_rank, "tensor_parallel_rotary_embedding") +logger = initialize_distributed_logger(_rank, "tensor_parallel_rotary_embedding") from rotary_embedding import RotaryAttention, parallel_rotary_block """ This example covers the rotary embedding in Llama3 model and is derived from https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning Command to run with single GPU: mpirun -n 1 --allow-run-as-root python tensor_parallel_rotary_embedding.py +Command to run with 2 GPUs: mpirun -n 2 --allow-run-as-root python tensor_parallel_rotary_embedding.py """ BATCH = 2 diff --git a/examples/distributed_inference/tensor_parallel_simple_example.py b/examples/distributed_inference/tensor_parallel_simple_example.py index ca0ecaf9a1..1f4b869ece 100755 --- a/examples/distributed_inference/tensor_parallel_simple_example.py +++ b/examples/distributed_inference/tensor_parallel_simple_example.py @@ -41,11 +41,11 @@ ) from torch_tensorrt.dynamo.distributed.utils import ( get_tensor_parallel_device_mesh, - initialize_logger, + initialize_distributed_logger, ) device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh() -logger = initialize_logger(_rank, "tensor_parallel_simple_example") +logger = initialize_distributed_logger(_rank, "tensor_parallel_simple_example") """ diff --git a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py index b1c13c105f..db14e3528b 100644 --- a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py @@ -12,7 +12,6 @@ from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( dynamo_tensorrt_converter, ) -from torch_tensorrt.dynamo.distributed.utils import load_tensorrt_llm_for_nccl from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import ( tensorrt_fused_nccl_all_gather_op, tensorrt_fused_nccl_reduce_scatter_op, diff --git a/py/torch_tensorrt/dynamo/distributed/utils.py b/py/torch_tensorrt/dynamo/distributed/utils.py index ad217a09af..e835b91439 100644 --- a/py/torch_tensorrt/dynamo/distributed/utils.py +++ b/py/torch_tensorrt/dynamo/distributed/utils.py @@ -1,18 +1,8 @@ -import ctypes -import getpass import logging import os -import platform -import tempfile -import urllib.request -from pathlib import Path -from typing import Optional import torch from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh -from torch_tensorrt._version import __tensorrt_llm_version__ - -_WHL_CPYTHON_VERSION = "cp310" logger = logging.getLogger(__name__) @@ -42,268 +32,10 @@ def get_tensor_parallel_device_mesh( return device_mesh, world_size, rank -def initialize_logger(rank: int, logger_file_name: str) -> logging.Logger: +def initialize_distributed_logger(rank: int, logger_file_name: str) -> logging.Logger: logger = logging.getLogger() logger.setLevel(logging.INFO) fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w") fh.setLevel(logging.INFO) logger.addHandler(fh) return logger - - -def is_platform_supported_for_trtllm() -> bool: - """ - Checks if the current platform supports TensorRT-LLM plugins for the NCCL backend. - - Returns: - bool: True if supported, False otherwise. - - Unsupported: - - Windows platforms - - Jetson/Orin/Xavier (aarch64 architecture + 'tegra' in platform release) - - CUDA 13 not supported - """ - system = platform.system().lower() - machine = platform.machine().lower() - release = platform.release().lower() - - if "windows" in system: - logger.info( - "TensorRT-LLM plugins for NCCL backend are not supported on Windows." - ) - return False - - if machine == "aarch64" and "tegra" in release: - logger.info( - "TensorRT-LLM plugins for NCCL backend are not supported on Jetson/Orin/Xavier (Tegra) devices." - ) - return False - - try: - cuda_version = torch.version.cuda # e.g., "12.4" or "13.0" - if cuda_version is None: - logger.warning("No CUDA runtime detected — TRT-LLM plugins unavailable.") - return False - - major, minor = map(int, cuda_version.split(".")) - if major != 12: - logger.warning("CUDA 13 is not supported for TRT-LLM plugins.") - return False - - return True - - except Exception as e: - logger.warning(f"Failed to detect CUDA version: {e}") - return False - - return True - - -def _cache_root() -> Path: - username = getpass.getuser() - return Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}" - - -def _extracted_dir_trtllm(platform_system: str, platform_machine: str) -> Path: - return ( - _cache_root() - / "trtllm" - / f"{__tensorrt_llm_version__}_{platform_system}_{platform_machine}" - ) - - -def extract_wheel_file(wheel_path: Path, extract_dir: Path) -> None: - from torch.distributed import barrier, get_rank, is_initialized - - if not is_initialized(): - # Single process case, just unzip - is_master = True - else: - is_master = get_rank() == 0 # only rank 0 does the unzip - - if is_master: - try: - import zipfile - except ImportError as e: - raise ImportError( - "zipfile module is required but not found. Please install zipfile" - ) - try: - with zipfile.ZipFile(wheel_path) as zip_ref: - zip_ref.extractall(extract_dir) - logger.debug(f"Extracted wheel to {extract_dir}") - - except FileNotFoundError as e: - # This should capture the errors in the download failure above - logger.error(f"Wheel file not found at {wheel_path}: {e}") - raise RuntimeError( - f"Failed to find downloaded wheel file at {wheel_path}" - ) from e - except zipfile.BadZipFile as e: - logger.error(f"Invalid or corrupted wheel file: {e}") - raise RuntimeError( - "Downloaded wheel file is corrupted or not a valid zip archive" - ) from e - except Exception as e: - logger.error(f"Unexpected error while extracting wheel: {e}") - raise RuntimeError( - "Unexpected error during extraction of TensorRT-LLM wheel" - ) from e - - # Make sure others wait until unzip is done - if is_initialized(): - barrier() - - -def download_and_get_plugin_lib_path() -> Optional[str]: - """ - Returns the path to the TensorRT‑LLM shared library, downloading and extracting if necessary. - - Args: - platform (str): Platform identifier (e.g., 'linux_x86_64') - - Returns: - Optional[str]: Path to shared library or None if operation fails. - """ - platform_system = platform.system().lower() - platform_machine = platform.machine().lower() - wheel_filename = ( - f"tensorrt_llm-{__tensorrt_llm_version__}-{_WHL_CPYTHON_VERSION}-" - f"{_WHL_CPYTHON_VERSION}-{platform_system}_{platform_machine}.whl" - ) - wheel_path = _cache_root() / wheel_filename - extract_dir = _extracted_dir_trtllm(platform_system, platform_machine) - # else will never be met though - lib_filename = ( - "libnvinfer_plugin_tensorrt_llm.so" - if "linux" in platform_system - else "libnvinfer_plugin_tensorrt_llm.dll" - ) - # eg: /tmp/torch_tensorrt_/trtllm/0.17.0.post1_linux_x86_64/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so - plugin_lib_path = extract_dir / "tensorrt_llm" / "libs" / lib_filename - - if plugin_lib_path.exists(): - return str(plugin_lib_path) - - wheel_path.parent.mkdir(parents=True, exist_ok=True) - extract_dir.mkdir(parents=True, exist_ok=True) - - if not wheel_path.exists(): - base_url = "https://pypi.nvidia.com/tensorrt-llm/" - download_url = base_url + wheel_filename - try: - logger.debug(f"Downloading {download_url} ...") - urllib.request.urlretrieve(download_url, wheel_path) - logger.debug("Download succeeded and TRT-LLM wheel is now present") - except urllib.error.HTTPError as e: - logger.error( - f"HTTP error {e.code} when trying to download {download_url}: {e.reason}" - ) - except urllib.error.URLError as e: - logger.error( - f"URL error when trying to download {download_url}: {e.reason}" - ) - except OSError as e: - logger.error(f"Local file write error: {e}") - - extract_wheel_file(wheel_path, extract_dir) - - try: - wheel_path.unlink(missing_ok=True) - logger.debug(f"Deleted wheel file: {wheel_path}") - except Exception as e: - logger.warning(f"Could not delete wheel file {wheel_path}: {e}") - if not plugin_lib_path.exists(): - logger.error( - f"Plugin library not found at expected location: {plugin_lib_path}" - ) - return None - - return str(plugin_lib_path) - - -def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool: - """ - Loads and initializes the TensorRT-LLM plugin from the given shared library path. - - Args: - plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library. - - Returns: - bool: True if successful, False otherwise. - """ - try: - handle = ctypes.CDLL(plugin_lib_path) - logger.info(f"Successfully loaded plugin library: {plugin_lib_path}") - except OSError as e_os_error: - if "libmpi" in str(e_os_error): - logger.warning( - f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}, got error {e_os_error} (hint: libmpi.so is a necessary dependency; ensure that OpenMPI or MPICH is installed on your system)", - exc_info=e_os_error, - ) - else: - logger.warning( - f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. " - f"Ensure the path is correct and the library is compatible.", - exc_info=e_os_error, - ) - return False - - try: - handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p] - handle.initTrtLlmPlugins.restype = ctypes.c_bool - except AttributeError as e_plugin_unavailable: - logger.warning( - "Unable to initialize the TensorRT-LLM plugin library", - exc_info=e_plugin_unavailable, - ) - return False - - try: - if handle.initTrtLlmPlugins(None, b"tensorrt_llm"): - logger.info("TensorRT-LLM plugin successfully initialized") - return True - else: - logger.warning("TensorRT-LLM plugin library failed in initialization") - return False - except Exception as e_initialization_error: - logger.warning( - "Exception occurred during TensorRT-LLM plugin library initialization", - exc_info=e_initialization_error, - ) - return False - return False - - -def load_tensorrt_llm_for_nccl() -> bool: - """ - Attempts to load the TensorRT-LLM plugin and initialize it. - Either the env variable TRTLLM_PLUGINS_PATH can specify the path - Or the user can specify USE_TRTLLM_PLUGINS as either of (1, true, yes, on) to download the TRT-LLM distribution and load it - - Returns: - bool: True if the plugin was successfully loaded and initialized, False otherwise. - """ - if not is_platform_supported_for_trtllm(): - return False - plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH") - - if plugin_lib_path: - return load_and_initialize_trtllm_plugin(plugin_lib_path) - else: - # this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user - use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in ( - "1", - "true", - "yes", - "on", - ) - if not use_trtllm_plugin: - logger.warning( - "Neither TRTLLM_PLUGIN_PATH is set nor is it directed to download the shared library. Please set either of the two to use TRT-LLM libraries in torchTRT" - ) - return False - - plugin_lib_path = download_and_get_plugin_lib_path() - return load_and_initialize_trtllm_plugin(plugin_lib_path) # type: ignore[arg-type] - return False diff --git a/tests/py/dynamo/distributed/distributed_utils.py b/tests/py/dynamo/distributed/distributed_utils.py index b13a07d308..6d13ecb1a1 100644 --- a/tests/py/dynamo/distributed/distributed_utils.py +++ b/tests/py/dynamo/distributed/distributed_utils.py @@ -9,6 +9,8 @@ from torch.distributed._tensor.device_mesh import init_device_mesh +# the below two functions are used to set the environment variables for the pytest single and multi process +# this is for the github CI where we use pytest def set_environment_variables_pytest_single_process(): port = 29500 + random.randint(1, 1000) os.environ["WORLD_SIZE"] = str(1) diff --git a/tests/py/dynamo/distributed/test_nccl_ops.py b/tests/py/dynamo/distributed/test_nccl_ops.py index c6e803ca1c..652bbe49e2 100644 --- a/tests/py/dynamo/distributed/test_nccl_ops.py +++ b/tests/py/dynamo/distributed/test_nccl_ops.py @@ -6,7 +6,7 @@ import torch.nn as nn from conversion.harness import DispatchTestCase -# The distributed env initialization has to be before torchTRT import since it uses barrier +# The distributed env initialization has to be before import of torchTRT, since it uses barrier for installation from distributed_utils import ( set_environment_variables_pytest_multi_process, set_environment_variables_pytest_single_process, @@ -26,6 +26,7 @@ init_method="env://", ) + class DistributedGatherModel(nn.Module): def __init__(self, input_dim, world_size, group_name): super().__init__() From 38224c5932075379ff328a0cbc6382a042679529 Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 13 Oct 2025 13:04:06 -0700 Subject: [PATCH 4/5] the barrier for TRT-LLM installation --- py/torch_tensorrt/_utils.py | 76 ++++++++++++++++++++++--------------- 1 file changed, 45 insertions(+), 31 deletions(-) diff --git a/py/torch_tensorrt/_utils.py b/py/torch_tensorrt/_utils.py index f59dce9b1c..523eccc968 100644 --- a/py/torch_tensorrt/_utils.py +++ b/py/torch_tensorrt/_utils.py @@ -143,13 +143,55 @@ def _extracted_dir_trtllm(platform_system: str, platform_machine: str) -> Path: ) +def extract_wheel_file(wheel_path: Path, extract_dir: Path) -> None: + # this will not be encountered in case of platforms not supporting torch distributed/nccl/TRT-LLM + from torch.distributed import barrier, get_rank, is_initialized + + if not is_initialized(): + # Single process case, just unzip + is_master = True + else: + is_master = get_rank() == 0 # only rank 0 does the unzip + + if is_master: + try: + import zipfile + except ImportError as e: + raise ImportError( + "zipfile module is required but not found. Please install zipfile" + ) + try: + with zipfile.ZipFile(wheel_path) as zip_ref: + zip_ref.extractall(extract_dir) + logger.debug(f"Extracted wheel to {extract_dir}") + + except FileNotFoundError as e: + # This should capture the errors in the download failure above + logger.error(f"Wheel file not found at {wheel_path}: {e}") + raise RuntimeError( + f"Failed to find downloaded wheel file at {wheel_path}" + ) from e + except zipfile.BadZipFile as e: + logger.error(f"Invalid or corrupted wheel file: {e}") + raise RuntimeError( + "Downloaded wheel file is corrupted or not a valid zip archive" + ) from e + except Exception as e: + logger.error(f"Unexpected error while extracting wheel: {e}") + raise RuntimeError( + "Unexpected error during extraction of TensorRT-LLM wheel" + ) from e + + # Make sure others wait until unzip is done + if is_initialized(): + barrier() + + def download_and_get_plugin_lib_path() -> Optional[str]: """ Returns the path to the TensorRT‑LLM shared library, downloading and extracting if necessary. - Args: platform (str): Platform identifier (e.g., 'linux_x86_64') - Returns: Optional[str]: Path to shared library or None if operation fails. """ @@ -194,32 +236,7 @@ def download_and_get_plugin_lib_path() -> Optional[str]: except OSError as e: logger.error(f"Local file write error: {e}") - try: - import zipfile - except ImportError as e: - raise ImportError( - "zipfile module is required but not found. Please install zipfile" - ) - try: - with zipfile.ZipFile(wheel_path) as zip_ref: - zip_ref.extractall(extract_dir) - logger.debug(f"Extracted wheel to {extract_dir}") - except FileNotFoundError as e: - # This should capture the errors in the download failure above - logger.error(f"Wheel file not found at {wheel_path}: {e}") - raise RuntimeError( - f"Failed to find downloaded wheel file at {wheel_path}" - ) from e - except zipfile.BadZipFile as e: - logger.error(f"Invalid or corrupted wheel file: {e}") - raise RuntimeError( - "Downloaded wheel file is corrupted or not a valid zip archive" - ) from e - except Exception as e: - logger.error(f"Unexpected error while extracting wheel: {e}") - raise RuntimeError( - "Unexpected error during extraction of TensorRT-LLM wheel" - ) from e + extract_wheel_file(wheel_path, extract_dir) try: wheel_path.unlink(missing_ok=True) @@ -238,10 +255,8 @@ def download_and_get_plugin_lib_path() -> Optional[str]: def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool: """ Loads and initializes the TensorRT-LLM plugin from the given shared library path. - Args: plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library. - Returns: bool: True if successful, False otherwise. """ @@ -293,7 +308,6 @@ def load_tensorrt_llm_for_nccl() -> bool: Attempts to load the TensorRT-LLM plugin and initialize it. Either the env variable TRTLLM_PLUGINS_PATH can specify the path Or the user can specify USE_TRTLLM_PLUGINS as either of (1, true, yes, on) to download the TRT-LLM distribution and load it - Returns: bool: True if the plugin was successfully loaded and initialized, False otherwise. """ From f07b5cb63b6083d45fd02731b789f4b1294b3546 Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 17 Oct 2025 15:50:28 -0700 Subject: [PATCH 5/5] changing the implementation of avoiding race condition in unzip of TRT-LLM wheel by using lock file --- .../tensor_parallel_initialize_dist.py | 38 +++++++++++++- .../tensor_parallel_rotary_embedding.py | 12 ++--- .../tensor_parallel_simple_example.py | 4 +- py/torch_tensorrt/_utils.py | 50 +++++++++++-------- .../dynamo/distributed/__init__.py | 1 - py/torch_tensorrt/dynamo/distributed/utils.py | 41 --------------- setup.py | 1 - tests/py/dynamo/distributed/test_nccl_ops.py | 1 - 8 files changed, 74 insertions(+), 74 deletions(-) delete mode 100644 py/torch_tensorrt/dynamo/distributed/__init__.py delete mode 100644 py/torch_tensorrt/dynamo/distributed/utils.py diff --git a/examples/distributed_inference/tensor_parallel_initialize_dist.py b/examples/distributed_inference/tensor_parallel_initialize_dist.py index 9aa715ae35..5fffb3fa00 100644 --- a/examples/distributed_inference/tensor_parallel_initialize_dist.py +++ b/examples/distributed_inference/tensor_parallel_initialize_dist.py @@ -14,7 +14,9 @@ import tensorrt as trt import torch import torch.distributed as dist -from torch.distributed._tensor.device_mesh import init_device_mesh +from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh + +logger = logging.getLogger(__name__) # this is kept at the application level, when mpirun is used to run the application @@ -54,3 +56,37 @@ def cleanup_distributed_env(): """Clean up distributed process group to prevent resource leaks.""" if dist.is_initialized(): dist.destroy_process_group() + + +def check_tensor_parallel_device_number(world_size: int) -> None: + if world_size % 2 != 0: + raise ValueError( + f"TP examples require even number of GPUs, but got {world_size} gpus" + ) + + +def get_tensor_parallel_device_mesh( + rank: int = 0, world_size: int = 1 +) -> tuple[DeviceMesh, int, int]: + local_rank = int( + os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count()) + ) + world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size)) + device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,)) + rank = device_mesh.get_rank() + assert rank == local_rank + device_id = ( + rank % torch.cuda.device_count() + ) # Ensure each rank gets a unique device + torch.cuda.set_device(device_id) + + return device_mesh, world_size, rank + + +def initialize_distributed_logger(rank: int, logger_file_name: str) -> logging.Logger: + logger = logging.getLogger() + logger.setLevel(logging.INFO) + fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w") + fh.setLevel(logging.INFO) + logger.addHandler(fh) + return logger diff --git a/examples/distributed_inference/tensor_parallel_rotary_embedding.py b/examples/distributed_inference/tensor_parallel_rotary_embedding.py index 7b0c9ef2e5..7a55497703 100644 --- a/examples/distributed_inference/tensor_parallel_rotary_embedding.py +++ b/examples/distributed_inference/tensor_parallel_rotary_embedding.py @@ -9,25 +9,21 @@ """ -import logging -import os import time import torch import torch.distributed as dist from tensor_parallel_initialize_dist import ( cleanup_distributed_env, + get_tensor_parallel_device_mesh, initialize_distributed_env, + initialize_distributed_logger, ) if not dist.is_initialized(): initialize_distributed_env() import torch_tensorrt -from torch_tensorrt.dynamo.distributed.utils import ( - get_tensor_parallel_device_mesh, - initialize_distributed_logger, -) device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh() logger = initialize_distributed_logger(_rank, "tensor_parallel_rotary_embedding") @@ -36,8 +32,8 @@ """ This example covers the rotary embedding in Llama3 model and is derived from https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning -Command to run with single GPU: mpirun -n 1 --allow-run-as-root python tensor_parallel_rotary_embedding.py -Command to run with 2 GPUs: mpirun -n 2 --allow-run-as-root python tensor_parallel_rotary_embedding.py +Command to run with single GPU: USE_TRTLLM_PLUGINS=1 mpirun -n 1 --allow-run-as-root python tensor_parallel_rotary_embedding.py +Command to run with 2 GPUs: USE_TRTLLM_PLUGINS=1 mpirun -n 2 --allow-run-as-root python tensor_parallel_rotary_embedding.py """ BATCH = 2 diff --git a/examples/distributed_inference/tensor_parallel_simple_example.py b/examples/distributed_inference/tensor_parallel_simple_example.py index 1f4b869ece..bf0c13560f 100755 --- a/examples/distributed_inference/tensor_parallel_simple_example.py +++ b/examples/distributed_inference/tensor_parallel_simple_example.py @@ -16,7 +16,7 @@ ----- .. code-block:: bash - mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py + USE_TRTLLM_PLUGINS=1 mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py """ import time @@ -27,7 +27,9 @@ import torch.nn as nn from tensor_parallel_initialize_dist import ( cleanup_distributed_env, + get_tensor_parallel_device_mesh, initialize_distributed_env, + initialize_distributed_logger, ) if not dist.is_initialized(): diff --git a/py/torch_tensorrt/_utils.py b/py/torch_tensorrt/_utils.py index 523eccc968..a259a54997 100644 --- a/py/torch_tensorrt/_utils.py +++ b/py/torch_tensorrt/_utils.py @@ -5,6 +5,7 @@ import platform import sys import tempfile +import time import urllib.request from pathlib import Path from typing import Any, Optional @@ -144,47 +145,57 @@ def _extracted_dir_trtllm(platform_system: str, platform_machine: str) -> Path: def extract_wheel_file(wheel_path: Path, extract_dir: Path) -> None: - # this will not be encountered in case of platforms not supporting torch distributed/nccl/TRT-LLM - from torch.distributed import barrier, get_rank, is_initialized - - if not is_initialized(): - # Single process case, just unzip - is_master = True - else: - is_master = get_rank() == 0 # only rank 0 does the unzip - - if is_master: + """ + Safely extract a wheel file to a directory with a lock to prevent concurrent extraction. + """ + rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0)) # MPI rank from OpenMPI + torch.cuda.set_device(rank) + lock_file = extract_dir / ".extracting" + + # Rank 0 performs extraction + if rank == 0: + logger.debug( + f"[Rank {rank}] Starting extraction of {wheel_path} to {extract_dir}" + ) try: import zipfile except ImportError as e: raise ImportError( "zipfile module is required but not found. Please install zipfile" ) + # Create lock file to signal extraction in progress + extract_dir.mkdir(parents=True, exist_ok=False) + lock_file.touch(exist_ok=False) try: with zipfile.ZipFile(wheel_path) as zip_ref: zip_ref.extractall(extract_dir) - logger.debug(f"Extracted wheel to {extract_dir}") - + logger.debug(f"[Rank {rank}] Extraction complete: {extract_dir}") except FileNotFoundError as e: - # This should capture the errors in the download failure above - logger.error(f"Wheel file not found at {wheel_path}: {e}") + logger.error(f"[Rank {rank}] Wheel file not found at {wheel_path}: {e}") raise RuntimeError( f"Failed to find downloaded wheel file at {wheel_path}" ) from e except zipfile.BadZipFile as e: - logger.error(f"Invalid or corrupted wheel file: {e}") + logger.error(f"[Rank {rank}] Invalid or corrupted wheel file: {e}") raise RuntimeError( "Downloaded wheel file is corrupted or not a valid zip archive" ) from e except Exception as e: - logger.error(f"Unexpected error while extracting wheel: {e}") + logger.error(f"[Rank {rank}] Unexpected error while extracting wheel: {e}") raise RuntimeError( "Unexpected error during extraction of TensorRT-LLM wheel" ) from e + finally: + # Remove lock file to signal completion + lock_file.unlink(missing_ok=True) - # Make sure others wait until unzip is done - if is_initialized(): - barrier() + else: + # Other ranks wait for extraction to complete + while lock_file.exists(): + logger.debug( + f"[Rank {rank}] Waiting for extraction to finish at {extract_dir}..." + ) + time.sleep(0.5) def download_and_get_plugin_lib_path() -> Optional[str]: @@ -216,7 +227,6 @@ def download_and_get_plugin_lib_path() -> Optional[str]: return str(plugin_lib_path) wheel_path.parent.mkdir(parents=True, exist_ok=True) - extract_dir.mkdir(parents=True, exist_ok=True) if not wheel_path.exists(): base_url = "https://pypi.nvidia.com/tensorrt-llm/" diff --git a/py/torch_tensorrt/dynamo/distributed/__init__.py b/py/torch_tensorrt/dynamo/distributed/__init__.py deleted file mode 100644 index 8b13789179..0000000000 --- a/py/torch_tensorrt/dynamo/distributed/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/py/torch_tensorrt/dynamo/distributed/utils.py b/py/torch_tensorrt/dynamo/distributed/utils.py deleted file mode 100644 index e835b91439..0000000000 --- a/py/torch_tensorrt/dynamo/distributed/utils.py +++ /dev/null @@ -1,41 +0,0 @@ -import logging -import os - -import torch -from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh - -logger = logging.getLogger(__name__) - - -def check_tensor_parallel_device_number(world_size: int) -> None: - if world_size % 2 != 0: - raise ValueError( - f"TP examples require even number of GPUs, but got {world_size} gpus" - ) - - -def get_tensor_parallel_device_mesh( - rank: int = 0, world_size: int = 1 -) -> tuple[DeviceMesh, int, int]: - local_rank = int( - os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count()) - ) - world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size)) - device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,)) - rank = device_mesh.get_rank() - assert rank == local_rank - device_id = ( - rank % torch.cuda.device_count() - ) # Ensure each rank gets a unique device - torch.cuda.set_device(device_id) - - return device_mesh, world_size, rank - - -def initialize_distributed_logger(rank: int, logger_file_name: str) -> logging.Logger: - logger = logging.getLogger() - logger.setLevel(logging.INFO) - fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w") - fh.setLevel(logging.INFO) - logger.addHandler(fh) - return logger diff --git a/setup.py b/setup.py index 34cb95ffd9..d487530626 100644 --- a/setup.py +++ b/setup.py @@ -450,7 +450,6 @@ def run(self): "torch_tensorrt.dynamo.conversion.impl.unary", "torch_tensorrt.dynamo.conversion.plugins", "torch_tensorrt.dynamo.debug", - "torch_tensorrt.dynamo.distributed", "torch_tensorrt.dynamo.lowering", "torch_tensorrt.dynamo.lowering.passes", "torch_tensorrt.dynamo.partitioning", diff --git a/tests/py/dynamo/distributed/test_nccl_ops.py b/tests/py/dynamo/distributed/test_nccl_ops.py index 652bbe49e2..d239179d23 100644 --- a/tests/py/dynamo/distributed/test_nccl_ops.py +++ b/tests/py/dynamo/distributed/test_nccl_ops.py @@ -6,7 +6,6 @@ import torch.nn as nn from conversion.harness import DispatchTestCase -# The distributed env initialization has to be before import of torchTRT, since it uses barrier for installation from distributed_utils import ( set_environment_variables_pytest_multi_process, set_environment_variables_pytest_single_process,