diff --git a/.github/workflows/build-test-linux-aarch64.yml b/.github/workflows/build-test-linux-aarch64.yml index 2604d18f92..9029ae9dca 100644 --- a/.github/workflows/build-test-linux-aarch64.yml +++ b/.github/workflows/build-test-linux-aarch64.yml @@ -356,6 +356,41 @@ jobs: python -m pytest -ra -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_core_test_results.xml . popd + tests-py-distributed: + name: Test dynamo distributed [Python] + needs: [filter-matrix, build] + if: false + strategy: + fail-fast: false + matrix: + include: + - repository: pytorch/tensorrt + package-name: torch_tensorrt + pre-script: packaging/pre_build_script.sh + post-script: packaging/post_build_script.sh + smoke-test-script: packaging/smoke_test_script.sh + uses: ./.github/workflows/linux-test.yml + with: + job-name: tests-py-dynamo-distributed + repository: "pytorch/tensorrt" + ref: "" + test-infra-repository: pytorch/test-infra + test-infra-ref: main + build-matrix: ${{ needs.filter-matrix.outputs.matrix }} + pre-script: ${{ matrix.pre-script }} + script: | + set -euo pipefail + export USE_HOST_DEPS=1 + export CI_BUILD=1 + export USE_TRTLLM_PLUGINS=1 + dnf install -y mpich mpich-devel openmpi openmpi-devel + pushd . + cd tests/py + cd dynamo + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_distributed_test_results.xml distributed/test_nccl_ops.py + popd + + concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ inputs.repository }}-${{ github.event_name == 'workflow_dispatch' }}-${{ inputs.job-name }} cancel-in-progress: true \ No newline at end of file diff --git a/.github/workflows/build-test-linux-x86_64.yml b/.github/workflows/build-test-linux-x86_64.yml index b1630c03be..1f099bca1e 100644 --- a/.github/workflows/build-test-linux-x86_64.yml +++ b/.github/workflows/build-test-linux-x86_64.yml @@ -340,6 +340,39 @@ jobs: python -m pytest -ra -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_core_test_results.xml . popd + tests-py-distributed: + name: Test dynamo distributed [Python] + needs: [filter-matrix, build] + strategy: + fail-fast: false + matrix: + include: + - repository: pytorch/tensorrt + package-name: torch_tensorrt + pre-script: packaging/pre_build_script.sh + post-script: packaging/post_build_script.sh + smoke-test-script: packaging/smoke_test_script.sh + uses: ./.github/workflows/linux-test.yml + with: + job-name: tests-py-dynamo-distributed + repository: "pytorch/tensorrt" + ref: "" + test-infra-repository: pytorch/test-infra + test-infra-ref: main + build-matrix: ${{ needs.filter-matrix.outputs.matrix }} + pre-script: ${{ matrix.pre-script }} + script: | + set -euo pipefail + export USE_HOST_DEPS=1 + export CI_BUILD=1 + export USE_TRTLLM_PLUGINS=1 + dnf install -y mpich mpich-devel openmpi openmpi-devel + pushd . 
+ cd tests/py + cd dynamo + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_distributed_test_results.xml distributed/test_nccl_ops.py + popd + concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-tensorrt-${{ inputs.repository }}-${{ github.event_name == 'workflow_dispatch' }}-${{ inputs.job-name }} cancel-in-progress: true diff --git a/dev_dep_versions.yml b/dev_dep_versions.yml index 113fe23de6..10f4120a47 100644 --- a/dev_dep_versions.yml +++ b/dev_dep_versions.yml @@ -1,3 +1,4 @@ __cuda_version__: "12.8" __tensorrt_version__: "10.12.0" __tensorrt_rtx_version__: "1.0.0" +__tensorrt_llm_version__: "0.17.0.post1" diff --git a/examples/distributed_inference/llama3_model.py b/examples/distributed_inference/llama3_model.py new file mode 100644 index 0000000000..dddee63871 --- /dev/null +++ b/examples/distributed_inference/llama3_model.py @@ -0,0 +1,496 @@ +# Taken and modified pytorch lightening +# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning + + +from dataclasses import dataclass +from typing import Any, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from torch.distributed._tensor import Replicate, Shard +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, + parallelize_module, +) + + +@dataclass +class ModelArgs: + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 # defined later by tokenizer + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + rope_theta: float = 10000 + + max_batch_size: int = 32 + max_seq_len: int = 2048 + # If `True`, then each transformer block init uses its layer ID, and if + # `False`, each uses the total number of transformer blocks + depth_init: bool = True + device: str = "cuda" + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: + """Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + freqs = torch.outer(t, freqs).float() + return torch.polar(torch.ones_like(freqs), freqs) # complex64 + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """Reshape frequency tensor for broadcasting it with another tensor. + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), + and the first seqlen elements will be sliced, but dim must match x. 
+ Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + Returns: + torch.Tensor: Reshaped frequency tensor. + """ + ndim = x.ndim + assert 0 <= 1 < ndim + seqlen = x.shape[1] + freqs_cis = freqs_cis[0:seqlen] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply rotary embeddings to input tensors using the given frequency tensor. + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. + xk (torch.Tensor): Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class RMSNorm(nn.Module): + """Initialize the RMSNorm normalization layer. + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + """ + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x: torch.Tensor): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x: torch.Tensor): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) # type: ignore + + +class Attention(nn.Module): + """Multi-head attention module. + Args: + model_args (ModelArgs): Model configuration arguments. + Attributes: + n_kv_heads (int): Number of key and value heads. + n_heads (int): Number of query heads. + n_rep (int): Number of repetitions for local heads. + head_dim (int): Dimension size of each attention head. + wq (Linear): Linear transformation for queries. + wk (Linear): Linear transformation for keys. + wv (Linear): Linear transformation for values. + wo (Linear): Linear transformation for output. 
+ """ + + def __init__(self, model_args: ModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.n_kv_heads = ( + model_args.n_heads + if model_args.n_kv_heads is None + else model_args.n_kv_heads + ) + self.n_rep = self.n_heads // self.n_kv_heads + self.head_dim = model_args.dim // model_args.n_heads + + self.wq = nn.Linear( + model_args.dim, model_args.n_heads * self.head_dim, bias=False + ) + self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear( + model_args.n_heads * self.head_dim, model_args.dim, bias=False + ) + + def init_weights(self, init_std: float) -> None: + for linear in (self.wq, self.wk, self.wv): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ) -> Any: + """Forward pass of the attention module. + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed frequency tensor. + Returns: + torch.Tensor: Output tensor after attention. + """ + bs, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + xq = xq.view(bs, seqlen, self.n_heads, self.head_dim) + xk = xk.view(bs, seqlen, self.n_kv_heads, self.head_dim) + xv = xv.view(bs, seqlen, self.n_kv_heads, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + + # we use casual mask for training + output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) + output = output.transpose( + 1, 2 + ).contiguous() # (bs, seqlen, n_local_heads, head_dim) + output = output.view(bs, seqlen, -1) + return self.wo(output) + + +class FeedForward(nn.Module): + """FeedForward module. + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None. + Attributes: + w1 (Linear): Linear transformation for the first layer. + w2 (Linear): Linear transformation for the second layer. + w3 (Linear): Linear transformation for the third layer. 
+ """ + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x) -> Any: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def init_weights(self, init_std: float) -> None: + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + +class TransformerBlock(nn.Module): + """TransformerBlock Module. + Args: + layer_id (int): Identifier for the layer. + model_args (ModelArgs): Model configuration arguments. + Attributes: + n_heads (int): Number of attention heads. + dim (int): Dimension size of the model. + head_dim (int): Dimension size of each attention head. + attention (Attention): Attention module. + feed_forward (FeedForward): FeedForward module. + layer_id (int): Identifier for the layer. + attention_norm (RMSNorm): Layer normalization for attention output. + ffn_norm (RMSNorm): Layer normalization for feedforward output. + """ + + def __init__(self, layer_id: int, model_args: ModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.dim = model_args.dim + self.attention = Attention(model_args) + self.feed_forward = FeedForward( + dim=model_args.dim, + hidden_dim=4 * model_args.dim, + multiple_of=model_args.multiple_of, + ffn_dim_multiplier=model_args.ffn_dim_multiplier, + ) + self.layer_id = layer_id + self.num_layers = model_args.n_layers + + self.attention_norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps) + self.ffn_norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps) + + if model_args.depth_init: + self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5 + else: + self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5 + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ): + """Perform a forward pass through the TransformerBlock. + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + Returns: + torch.Tensor: Output tensor after applying attention and feedforward layers. + """ + h = x + self.attention(self.attention_norm(x), freqs_cis) + return h + self.feed_forward(self.ffn_norm(h)) + + def init_weights(self): + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + self.feed_forward.init_weights(self.weight_init_std) + + +class ParallelTransformer(nn.Module): + """Transformer Module. + Args: + model_args (ModelArgs): Model configuration arguments. + Attributes: + model_args (ModelArgs): Model configuration arguments. + vocab_size (int): Vocabulary size. + n_layers (int): Number of layers in the model. + tok_embeddings (ParallelEmbedding): Token embeddings. + layers (torch.nn.ModuleList): List of Transformer blocks. + norm (RMSNorm): Layer normalization for the model output. + output (ColumnParallelLinear): Linear layer for final output. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. 
+ """ + + def __init__(self, model_args: ModelArgs, tp_mesh: DeviceMesh = None): + # Here we use distributed model initialization to avoid memory overflow + super().__init__() + self.model_args = model_args + self.vocab_size = model_args.vocab_size + self.n_layers = model_args.n_layers + + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + self.tok_embeddings.to(model_args.device) + self.tok_embeddings = self.parallel_embeddings(self.tok_embeddings, tp_mesh) + + # TODO persistent should be set to false, since this buffer can be recomputed. + # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411, + # compile or pipeline-tracer will not correctly handle non-persistent buffers, + # so we need to fix that. (2) if we initialize pipeline-parallel models from + # a seed checkpoint rather than calling init_weights, we need freqs_cis to be + # initialized by the checkpoint, or we need to add a separate initializer for + # just the non-persistent buffers that is called after loading checkpoints. + self.register_buffer( + "freqs_cis", + self._precompute_freqs_cis().to(model_args.device), + persistent=True, + ) + + self.layers = torch.nn.ModuleDict().to(model_args.device) + for layer_id in range(model_args.n_layers): + block = TransformerBlock(layer_id, model_args).to(model_args.device) + self.layers[str(layer_id)] = block + self.parallel_transformer_block(self.layers[str(layer_id)], tp_mesh) + + self.norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps).to( + model_args.device + ) + self.norm = self.parallel_norm(self.norm, tp_mesh) + self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False).to( + model_args.device + ) + self.output = self.parallel_output(self.output, tp_mesh) + self.init_weights() + + def parallel_transformer_block(self, transformer_block, tp_mesh): + if tp_mesh.size() <= 1: + return + plan = { + "attention": PrepareModuleInput( + input_layouts=(Shard(1), None), + desired_input_layouts=(Replicate(), None), + ), + "attention.wq": ColwiseParallel(), + "attention.wk": ColwiseParallel(), + "attention.wv": ColwiseParallel(), + "attention.wo": RowwiseParallel(output_layouts=Shard(1)), + "attention_norm": SequenceParallel(), + "feed_forward": PrepareModuleInput( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + "feed_forward.w1": ColwiseParallel(), + "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)), + "feed_forward.w3": ColwiseParallel(), + "ffn_norm": SequenceParallel(), + } + + # Adjust attention module to use the local number of heads + attn_layer = transformer_block.attention + attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size() + attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size() + + # Apply the plan for the current transformer block + parallelize_module(transformer_block, tp_mesh, plan) + + def parallel_embeddings(self, embedding, tp_mesh): + plan = { + "tok_embeddings": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + ) + } + return parallelize_module(embedding, tp_mesh, plan) + + def parallel_output(self, output, tp_mesh): + plan = { + "output": ColwiseParallel( + input_layouts=Shard(1), + ), + } + return parallelize_module(output, tp_mesh, plan) + + def parallel_norm(self, norm, tp_mesh): + plan = { + "norm": SequenceParallel(), + } + return parallelize_module(norm, tp_mesh, plan) + + def reset_parameters(self): + with torch.device(self.freqs_cis.device): + self.freqs_cis = self._precompute_freqs_cis() + + def init_weights(self): 
+ """[Note: On ``init_weights`` vs. + ``reset_parameters``] + Modules may define ``reset_parameters`` to initialize parameter values. + ``reset_parameters`` is meant to only initialize directly owned + parameters/buffers, not those of their child modules, and it can be + used to give the initial values for these tensors. + Separately, users may want custom initialization for their modules, + different from that in ``reset_parameters``. For this, we define + ``init_weights``. We only call it in the constructor of this + ``Transformer`` root module to avoid reinitializing tensors. + """ + with torch.device(self.freqs_cis.device): + self.freqs_cis = self._precompute_freqs_cis() + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers.values(): + layer.init_weights() + self.norm.reset_parameters() + final_out_std = self.model_args.dim**-0.5 + cutoff_factor = 3 + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + def _precompute_freqs_cis(self) -> torch.Tensor: + return precompute_freqs_cis( + self.model_args.dim // self.model_args.n_heads, + # Need to compute until at least the max token limit for generation + # (use 2x max sequence length to be safe) + self.model_args.max_seq_len * 2, + self.model_args.rope_theta, + ) + + def forward(self, tokens: torch.Tensor): + """Perform a forward pass through the Transformer model. + Args: + tokens (torch.Tensor): Input token indices. + Returns: + torch.Tensor: Output logits after applying the Transformer model. + """ + # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages + h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens + + for layer in self.layers.values(): + h = layer(h, self.freqs_cis) + + h = self.norm(h) if self.norm else h + return self.output(h).float() if self.output else h + + @classmethod + def from_model_args(cls, model_args: ModelArgs) -> "Transformer": + """Initialize a Transformer model from a ModelArgs object. + Args: + model_args (ModelArgs): Model configuration arguments. + Returns: + Transformer: Transformer model. 
+ """ + return cls(model_args) diff --git a/examples/distributed_inference/rotary_embedding.py b/examples/distributed_inference/rotary_embedding.py index 1153ea2180..6c18f9eb8f 100644 --- a/examples/distributed_inference/rotary_embedding.py +++ b/examples/distributed_inference/rotary_embedding.py @@ -84,20 +84,20 @@ def parallel_rotary_block(rotary_block, tp_mesh): "wk": ColwiseParallel(), "wo": RowwiseParallel(output_layouts=Shard(0)), } - rotary_block.n_parallel = 1 # this is for single GPU, to do remove this hardcode + rotary_block.n_parallel = tp_mesh.size() parallelize_module(rotary_block, tp_mesh, plan) class RotaryAttention(nn.Module): - def __init__(self, dim: int, seq_len: int): + def __init__(self, dim: int, seq_len: int, n_parallel: int = 1): super().__init__() self.dim = dim self.wq = nn.Linear(dim, dim) self.wk = nn.Linear(dim, dim) self.wo = nn.Linear(dim, dim) self.seq_len = seq_len - self.n_parallel = 1 + self.n_parallel = n_parallel self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True) self.init_weights() diff --git a/examples/distributed_inference/tensor_parallel_initialize_dist.py b/examples/distributed_inference/tensor_parallel_initialize_dist.py index 98d3ca18e9..068316659e 100644 --- a/examples/distributed_inference/tensor_parallel_initialize_dist.py +++ b/examples/distributed_inference/tensor_parallel_initialize_dist.py @@ -17,29 +17,7 @@ from torch.distributed._tensor.device_mesh import init_device_mesh -def find_repo_root(max_depth=10): - dir_path = os.path.dirname(os.path.realpath(__file__)) - for i in range(max_depth): - files = os.listdir(dir_path) - if "MODULE.bazel" in files: - return dir_path - else: - dir_path = os.path.dirname(dir_path) - - raise RuntimeError("Could not find repo root") - - -def initialize_logger(rank, logger_file_name): - logger = logging.getLogger() - logger.setLevel(logging.INFO) - fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w") - fh.setLevel(logging.INFO) - logger.addHandler(fh) - return logger - - -# This is required for env initialization since we use mpirun -def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500): +def initialize_distributed_env(rank=0, world_size=1, port=29500): local_rank = int( os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count()) ) @@ -50,9 +28,6 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950 os.environ["WORLD_SIZE"] = str(world_size) os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = str(port) - os.environ["TRTLLM_PLUGINS_PATH"] = ( - find_repo_root() + "/lib/libnvinfer_plugin_tensorrt_llm.so" - ) # Necessary to assign a device to each rank. 
torch.cuda.set_device(local_rank) @@ -66,13 +41,12 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950 device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,)) rank = device_mesh.get_rank() assert rank == local_rank - logger = initialize_logger(rank, logger_file_name) device_id = ( rank % torch.cuda.device_count() ) # Ensure each rank gets a unique device torch.cuda.set_device(device_id) - return device_mesh, world_size, rank, logger + return device_mesh, world_size, rank def cleanup_distributed_env(): diff --git a/examples/distributed_inference/tensor_parallel_llama3.py b/examples/distributed_inference/tensor_parallel_llama3.py new file mode 100644 index 0000000000..f8a262ee40 --- /dev/null +++ b/examples/distributed_inference/tensor_parallel_llama3.py @@ -0,0 +1,72 @@ +# Taken and modified pytorch lightening +# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning +import logging +import os +import time + +import torch +import torch.distributed as dist +from llama3_model import ModelArgs, ParallelTransformer +from tensor_parallel_initialize_dist import ( + cleanup_distributed_env, + initialize_distributed_env, +) +from torch.distributed._composable.fsdp import MixedPrecisionPolicy +from torch.distributed._composable.fsdp.fully_shard import fully_shard +from torch.distributed._tensor import Replicate, Shard +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper, +) + +if not dist.is_initialized(): + initialize_distributed_env() + +import torch_tensorrt +from torch_tensorrt.dynamo.distributed.utils import ( + get_tensor_parallel_device_mesh, + initialize_logger, +) + +device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh() +logger = initialize_logger(_rank, "tensor_parallel_llama3") + +logger.info(f"Starting PyTorch TP example on rank {_rank}.") +assert ( + _world_size % 2 == 0 +), f"TP examples require even number of GPUs, but got {_world_size} gpus" + +model_args = ModelArgs( + vocab_size=32000, + dim=1024, + n_layers=4, + n_heads=8, + rope_theta=500000.0, + n_kv_heads=8, + device="cuda", +) + +with torch.no_grad(): + model = ParallelTransformer(model_args, device_mesh) + torch.manual_seed(0) + inp = torch.randint(32000, (8, 256), device="cuda") + python_result = model(inp) + torch_tensorrt.runtime.set_multi_device_safe_mode(True) + model = torch.compile( + model, + fullgraph=True, + backend="torch_tensorrt", + options={ + "use_python_runtime": True, + "use_distributed_mode_trace": True, + "debug": True, + }, + dynamic=False, + ) + + start = time.time() + output = model(inp) + end = time.time() + logger.info(f"Compilation time is {end-start}") + assert (python_result - output).std() < 0.01, "Compilation result is not correct." 
+ + cleanup_distributed_env() diff --git a/examples/distributed_inference/tensor_parallel_rotary_embedding.py b/examples/distributed_inference/tensor_parallel_rotary_embedding.py index da3f3fd8fd..b5412b336b 100644 --- a/examples/distributed_inference/tensor_parallel_rotary_embedding.py +++ b/examples/distributed_inference/tensor_parallel_rotary_embedding.py @@ -14,17 +14,25 @@ import time import torch -import torch_tensorrt -from rotary_embedding import RotaryAttention, parallel_rotary_block +import torch.distributed as dist from tensor_parallel_initialize_dist import ( cleanup_distributed_env, initialize_distributed_env, ) -device_mesh, _world_size, _rank, logger = initialize_distributed_env( - "./tensor_parallel_rotary_embedding" +if not dist.is_initialized(): + initialize_distributed_env() + +import torch_tensorrt +from torch_tensorrt.dynamo.distributed.utils import ( + get_tensor_parallel_device_mesh, + initialize_logger, ) +device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh() +logger = initialize_logger(_rank, "tensor_parallel_rotary_embedding") + +from rotary_embedding import RotaryAttention, parallel_rotary_block """ This example covers the rotary embedding in Llama3 model and is derived from https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning @@ -37,7 +45,7 @@ DIM = 128 with torch.no_grad(): - model = RotaryAttention(DIM, SEQ_LEN) + model = RotaryAttention(DIM, SEQ_LEN, device_mesh.size()) parallel_rotary_block(model, device_mesh) device = torch.device("cuda", device_mesh.get_rank()) model.to(device) diff --git a/examples/distributed_inference/tensor_parallel_simple_example.py b/examples/distributed_inference/tensor_parallel_simple_example.py index c5688c6e5b..ca0ecaf9a1 100755 --- a/examples/distributed_inference/tensor_parallel_simple_example.py +++ b/examples/distributed_inference/tensor_parallel_simple_example.py @@ -25,22 +25,29 @@ import torch import torch.distributed as dist import torch.nn as nn -import torch_tensorrt from tensor_parallel_initialize_dist import ( cleanup_distributed_env, initialize_distributed_env, ) + +if not dist.is_initialized(): + initialize_distributed_env() +import torch_tensorrt from torch.distributed._tensor import Shard from torch.distributed.tensor.parallel import ( ColwiseParallel, RowwiseParallel, parallelize_module, ) - -device_mesh, _world_size, _rank, logger = initialize_distributed_env( - "./tensor_parallel_simple_example" +from torch_tensorrt.dynamo.distributed.utils import ( + get_tensor_parallel_device_mesh, + initialize_logger, ) +device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh() +logger = initialize_logger(_rank, "tensor_parallel_simple_example") + + """ This example takes some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py """ diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 0dc4654db0..e2ed813cec 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -13,7 +13,6 @@ from torch_tensorrt._Device import Device from torch_tensorrt._enums import EngineCapability, dtype from torch_tensorrt._features import needs_cross_compile -from torch_tensorrt._Input import Input from torch_tensorrt.dynamo import _defaults, partitioning from torch_tensorrt.dynamo._DryRunTracker import ( DryRunTracker, @@ -103,6 +102,7 @@ def cross_compile_for_windows( tiling_optimization_level: str = 
_defaults.TILING_OPTIMIZATION_LEVEL, l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, + use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows @@ -176,6 +176,7 @@ def cross_compile_for_windows( enable_weight_streaming (bool): Enable weight streaming. tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"]. l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). + use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -285,7 +286,6 @@ def cross_compile_for_windows( arg_inputs = [arg_inputs] # type: ignore # Prepare torch_trt inputs - trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs) trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs) device = to_torch_tensorrt_device(device) enabled_precisions = {dtype._from(p) for p in enabled_precisions} @@ -330,6 +330,7 @@ def cross_compile_for_windows( "enable_weight_streaming": enable_weight_streaming, "tiling_optimization_level": tiling_optimization_level, "l2_limit_for_tiling": l2_limit_for_tiling, + "use_distributed_mode_trace": use_distributed_mode_trace, } # disable the following settings is not supported for cross compilation for windows feature @@ -374,7 +375,6 @@ def cross_compile_for_windows( ) trt_gm = compile_module( gm, - trt_arg_inputs, trt_kwarg_inputs, settings, ) @@ -430,6 +430,7 @@ def compile( tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL, l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, + use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -506,6 +507,7 @@ def compile( tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"]. l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage. + use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. 
This is enabled when DTensors or distributed tensors are present in distributed model **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -618,7 +620,6 @@ def compile( arg_inputs = [arg_inputs] # type: ignore # Prepare torch_trt inputs - trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs) trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs) device = to_torch_tensorrt_device(device) enabled_precisions = {dtype._from(p) for p in enabled_precisions} @@ -674,6 +675,7 @@ def compile( "tiling_optimization_level": tiling_optimization_level, "l2_limit_for_tiling": l2_limit_for_tiling, "offload_module_to_cpu": offload_module_to_cpu, + "use_distributed_mode_trace": use_distributed_mode_trace, } settings = CompilationSettings(**compilation_options) @@ -703,16 +705,13 @@ def compile( logger.warning( "Remaining GPU memory may not be enough to compile the TensorRT engine for this model resulting in an OOM error, Consider setting offload_module_to_cpu=True" ) - trt_gm = compile_module( - gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache - ) + trt_gm = compile_module(gm, trt_kwarg_inputs, settings, engine_cache) return trt_gm @fn_supports_debugger # type: ignore[misc] def compile_module( gm: torch.fx.GraphModule, - sample_arg_inputs: Sequence[Input], sample_kwarg_inputs: Optional[dict[Any, Any]] = None, settings: CompilationSettings = CompilationSettings(), engine_cache: Optional[BaseEngineCache] = None, @@ -1045,6 +1044,7 @@ def convert_exported_program_to_serialized_trt_engine( tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL, l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, + use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, **kwargs: Any, ) -> bytes: """Convert an ExportedProgram to a serialized TensorRT engine @@ -1118,6 +1118,7 @@ def convert_exported_program_to_serialized_trt_engine( tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"]. l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage. + use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model. 
**kwargs: Any, Returns: bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs @@ -1286,6 +1287,7 @@ def convert_exported_program_to_serialized_trt_engine( "tiling_optimization_level": tiling_optimization_level, "l2_limit_for_tiling": l2_limit_for_tiling, "offload_module_to_cpu": offload_module_to_cpu, + "use_distributed_mode_trace": use_distributed_mode_trace, } settings = CompilationSettings(**compilation_options) diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index c39fe57197..cb25a105f2 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -22,7 +22,6 @@ from torch_tensorrt.dynamo.utils import ( is_tegra_platform, parse_dynamo_kwargs, - prepare_inputs, set_log_level, ) @@ -150,9 +149,6 @@ def _pretraced_backend( logger.debug("Lowered Input graph:\n " + str(gm.graph)) - torchtrt_inputs = prepare_inputs( - torch_inputs, disable_memory_format_check=True - ) if settings.require_full_compilation: logger.warning( "require_full_compilation arg is not applicable for torch.compile with backend='torch_tensorrt" @@ -163,7 +159,6 @@ def _pretraced_backend( ) trt_compiled = compile_module( gm, - torchtrt_inputs, settings=settings, engine_cache=engine_cache, ) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 3828f97f99..094de488ec 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -1,8 +1,6 @@ import collections -import ctypes import functools import logging -import os from typing import ( Any, Callable, @@ -1124,69 +1122,6 @@ def args_bounds_check( return args[i] if len(args) > i and args[i] is not None else replacement -def load_tensorrt_llm() -> bool: - """ - Attempts to load the TensorRT-LLM plugin and initialize it. - - Returns: - bool: True if the plugin was successfully loaded and initialized, False otherwise. - """ - try: - import tensorrt_llm as trt_llm # noqa: F401 - - _LOGGER.info("TensorRT-LLM successfully imported") - return True - except (ImportError, AssertionError) as e_import_error: - # Check for environment variable for the plugin library path - plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH") - if not plugin_lib_path: - _LOGGER.warning( - "TensorRT-LLM is not installed. 
Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops", - ) - return False - - _LOGGER.info(f"TensorRT-LLM Plugin lib path found: {plugin_lib_path}") - try: - # Load the shared library - handle = ctypes.CDLL(plugin_lib_path) - _LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}") - except OSError as e_os_error: - _LOGGER.error( - f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}" - f"Ensure the path is correct and the library is compatible", - exc_info=e_os_error, - ) - return False - - try: - # Configure plugin initialization arguments - handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p] - handle.initTrtLlmPlugins.restype = ctypes.c_bool - except AttributeError as e_plugin_unavailable: - _LOGGER.warning( - "Unable to initialize the TensorRT-LLM plugin library", - exc_info=e_plugin_unavailable, - ) - return False - - try: - # Initialize the plugin - TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm" - if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")): - _LOGGER.info("TensorRT-LLM plugin successfully initialized") - return True - else: - _LOGGER.warning("TensorRT-LLM plugin library failed in initialization") - return False - except Exception as e_initialization_error: - _LOGGER.warning( - "Exception occurred during TensorRT-LLM plugin library initialization", - exc_info=e_initialization_error, - ) - return False - return False - - def promote_trt_tensors_to_same_dtype( ctx: ConversionContext, lhs: TRTTensor, rhs: TRTTensor, name_prefix: str ) -> tuple[TRTTensor, TRTTensor]: diff --git a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py index 1442c2b17b..045fa6b149 100644 --- a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py @@ -11,15 +11,15 @@ from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( dynamo_tensorrt_converter, ) -from torch_tensorrt.dynamo.conversion.converter_utils import load_tensorrt_llm +from torch_tensorrt.dynamo.distributed.utils import load_tensorrt_llm_for_nccl +from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import ( + tensorrt_fused_nccl_all_gather_op, + tensorrt_fused_nccl_reduce_scatter_op, +) _LOGGER: logging.Logger = logging.getLogger(__name__) -if load_tensorrt_llm(): - from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import ( - tensorrt_fused_nccl_all_gather_op, - tensorrt_fused_nccl_reduce_scatter_op, - ) +if load_tensorrt_llm_for_nccl(): @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op) def fused_nccl_gather( diff --git a/py/torch_tensorrt/dynamo/distributed/__init__.py b/py/torch_tensorrt/dynamo/distributed/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/py/torch_tensorrt/dynamo/distributed/utils.py b/py/torch_tensorrt/dynamo/distributed/utils.py new file mode 100644 index 0000000000..ad217a09af --- /dev/null +++ b/py/torch_tensorrt/dynamo/distributed/utils.py @@ -0,0 +1,309 @@ +import ctypes +import getpass +import logging +import os +import platform +import tempfile +import urllib.request +from pathlib import Path +from typing import Optional + +import torch +from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh +from torch_tensorrt._version import __tensorrt_llm_version__ + +_WHL_CPYTHON_VERSION = "cp310" + +logger = 
logging.getLogger(__name__) + + +def check_tensor_parallel_device_number(world_size: int) -> None: + if world_size % 2 != 0: + raise ValueError( + f"TP examples require even number of GPUs, but got {world_size} gpus" + ) + + +def get_tensor_parallel_device_mesh( + rank: int = 0, world_size: int = 1 +) -> tuple[DeviceMesh, int, int]: + local_rank = int( + os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count()) + ) + world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size)) + device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,)) + rank = device_mesh.get_rank() + assert rank == local_rank + device_id = ( + rank % torch.cuda.device_count() + ) # Ensure each rank gets a unique device + torch.cuda.set_device(device_id) + + return device_mesh, world_size, rank + + +def initialize_logger(rank: int, logger_file_name: str) -> logging.Logger: + logger = logging.getLogger() + logger.setLevel(logging.INFO) + fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w") + fh.setLevel(logging.INFO) + logger.addHandler(fh) + return logger + + +def is_platform_supported_for_trtllm() -> bool: + """ + Checks if the current platform supports TensorRT-LLM plugins for the NCCL backend. + + Returns: + bool: True if supported, False otherwise. + + Unsupported: + - Windows platforms + - Jetson/Orin/Xavier (aarch64 architecture + 'tegra' in platform release) + - CUDA 13 not supported + """ + system = platform.system().lower() + machine = platform.machine().lower() + release = platform.release().lower() + + if "windows" in system: + logger.info( + "TensorRT-LLM plugins for NCCL backend are not supported on Windows." + ) + return False + + if machine == "aarch64" and "tegra" in release: + logger.info( + "TensorRT-LLM plugins for NCCL backend are not supported on Jetson/Orin/Xavier (Tegra) devices." + ) + return False + + try: + cuda_version = torch.version.cuda # e.g., "12.4" or "13.0" + if cuda_version is None: + logger.warning("No CUDA runtime detected — TRT-LLM plugins unavailable.") + return False + + major, minor = map(int, cuda_version.split(".")) + if major != 12: + logger.warning("CUDA 13 is not supported for TRT-LLM plugins.") + return False + + return True + + except Exception as e: + logger.warning(f"Failed to detect CUDA version: {e}") + return False + + return True + + +def _cache_root() -> Path: + username = getpass.getuser() + return Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}" + + +def _extracted_dir_trtllm(platform_system: str, platform_machine: str) -> Path: + return ( + _cache_root() + / "trtllm" + / f"{__tensorrt_llm_version__}_{platform_system}_{platform_machine}" + ) + + +def extract_wheel_file(wheel_path: Path, extract_dir: Path) -> None: + from torch.distributed import barrier, get_rank, is_initialized + + if not is_initialized(): + # Single process case, just unzip + is_master = True + else: + is_master = get_rank() == 0 # only rank 0 does the unzip + + if is_master: + try: + import zipfile + except ImportError as e: + raise ImportError( + "zipfile module is required but not found. 
Please install zipfile" + ) + try: + with zipfile.ZipFile(wheel_path) as zip_ref: + zip_ref.extractall(extract_dir) + logger.debug(f"Extracted wheel to {extract_dir}") + + except FileNotFoundError as e: + # This should capture the errors in the download failure above + logger.error(f"Wheel file not found at {wheel_path}: {e}") + raise RuntimeError( + f"Failed to find downloaded wheel file at {wheel_path}" + ) from e + except zipfile.BadZipFile as e: + logger.error(f"Invalid or corrupted wheel file: {e}") + raise RuntimeError( + "Downloaded wheel file is corrupted or not a valid zip archive" + ) from e + except Exception as e: + logger.error(f"Unexpected error while extracting wheel: {e}") + raise RuntimeError( + "Unexpected error during extraction of TensorRT-LLM wheel" + ) from e + + # Make sure others wait until unzip is done + if is_initialized(): + barrier() + + +def download_and_get_plugin_lib_path() -> Optional[str]: + """ + Returns the path to the TensorRT‑LLM shared library, downloading and extracting if necessary. + + Args: + platform (str): Platform identifier (e.g., 'linux_x86_64') + + Returns: + Optional[str]: Path to shared library or None if operation fails. + """ + platform_system = platform.system().lower() + platform_machine = platform.machine().lower() + wheel_filename = ( + f"tensorrt_llm-{__tensorrt_llm_version__}-{_WHL_CPYTHON_VERSION}-" + f"{_WHL_CPYTHON_VERSION}-{platform_system}_{platform_machine}.whl" + ) + wheel_path = _cache_root() / wheel_filename + extract_dir = _extracted_dir_trtllm(platform_system, platform_machine) + # else will never be met though + lib_filename = ( + "libnvinfer_plugin_tensorrt_llm.so" + if "linux" in platform_system + else "libnvinfer_plugin_tensorrt_llm.dll" + ) + # eg: /tmp/torch_tensorrt_/trtllm/0.17.0.post1_linux_x86_64/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so + plugin_lib_path = extract_dir / "tensorrt_llm" / "libs" / lib_filename + + if plugin_lib_path.exists(): + return str(plugin_lib_path) + + wheel_path.parent.mkdir(parents=True, exist_ok=True) + extract_dir.mkdir(parents=True, exist_ok=True) + + if not wheel_path.exists(): + base_url = "https://pypi.nvidia.com/tensorrt-llm/" + download_url = base_url + wheel_filename + try: + logger.debug(f"Downloading {download_url} ...") + urllib.request.urlretrieve(download_url, wheel_path) + logger.debug("Download succeeded and TRT-LLM wheel is now present") + except urllib.error.HTTPError as e: + logger.error( + f"HTTP error {e.code} when trying to download {download_url}: {e.reason}" + ) + except urllib.error.URLError as e: + logger.error( + f"URL error when trying to download {download_url}: {e.reason}" + ) + except OSError as e: + logger.error(f"Local file write error: {e}") + + extract_wheel_file(wheel_path, extract_dir) + + try: + wheel_path.unlink(missing_ok=True) + logger.debug(f"Deleted wheel file: {wheel_path}") + except Exception as e: + logger.warning(f"Could not delete wheel file {wheel_path}: {e}") + if not plugin_lib_path.exists(): + logger.error( + f"Plugin library not found at expected location: {plugin_lib_path}" + ) + return None + + return str(plugin_lib_path) + + +def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool: + """ + Loads and initializes the TensorRT-LLM plugin from the given shared library path. + + Args: + plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library. + + Returns: + bool: True if successful, False otherwise. 
+ """ + try: + handle = ctypes.CDLL(plugin_lib_path) + logger.info(f"Successfully loaded plugin library: {plugin_lib_path}") + except OSError as e_os_error: + if "libmpi" in str(e_os_error): + logger.warning( + f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}, got error {e_os_error} (hint: libmpi.so is a necessary dependency; ensure that OpenMPI or MPICH is installed on your system)", + exc_info=e_os_error, + ) + else: + logger.warning( + f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. " + f"Ensure the path is correct and the library is compatible.", + exc_info=e_os_error, + ) + return False + + try: + handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p] + handle.initTrtLlmPlugins.restype = ctypes.c_bool + except AttributeError as e_plugin_unavailable: + logger.warning( + "Unable to initialize the TensorRT-LLM plugin library", + exc_info=e_plugin_unavailable, + ) + return False + + try: + if handle.initTrtLlmPlugins(None, b"tensorrt_llm"): + logger.info("TensorRT-LLM plugin successfully initialized") + return True + else: + logger.warning("TensorRT-LLM plugin library failed in initialization") + return False + except Exception as e_initialization_error: + logger.warning( + "Exception occurred during TensorRT-LLM plugin library initialization", + exc_info=e_initialization_error, + ) + return False + return False + + +def load_tensorrt_llm_for_nccl() -> bool: + """ + Attempts to load the TensorRT-LLM plugin and initialize it. + Either the env variable TRTLLM_PLUGINS_PATH can specify the path + Or the user can specify USE_TRTLLM_PLUGINS as either of (1, true, yes, on) to download the TRT-LLM distribution and load it + + Returns: + bool: True if the plugin was successfully loaded and initialized, False otherwise. + """ + if not is_platform_supported_for_trtllm(): + return False + plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH") + + if plugin_lib_path: + return load_and_initialize_trtllm_plugin(plugin_lib_path) + else: + # this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user + use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in ( + "1", + "true", + "yes", + "on", + ) + if not use_trtllm_plugin: + logger.warning( + "Neither TRTLLM_PLUGIN_PATH is set nor is it directed to download the shared library. 
Please set either of the two to use TRT-LLM libraries in torchTRT" + ) + return False + + plugin_lib_path = download_and_get_plugin_lib_path() + return load_and_initialize_trtllm_plugin(plugin_lib_path) # type: ignore[arg-type] + return False diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_modify_reshape_complex_nodes.py b/py/torch_tensorrt/dynamo/lowering/passes/_modify_reshape_complex_nodes.py deleted file mode 100644 index f8ca1f71b9..0000000000 --- a/py/torch_tensorrt/dynamo/lowering/passes/_modify_reshape_complex_nodes.py +++ /dev/null @@ -1,105 +0,0 @@ -import logging - -import torch - -logger = logging.getLogger(__name__) - -from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( - clean_up_graph_after_modifications, - find_complex_nodes, -) - -from ._replace_complex_placeholder_to_tuple import replace_complex_placeholder_to_tuple - - -def tensorrt_complex_mul(args0, args1): - args0_real, args0_imag = torch.ops.aten.split.Tensor(args0, 1, -1) - args1_real, args1_imag = torch.ops.aten.split.Tensor(args1, 1, -1) - - args0_real = torch.ops.aten.squeeze.dim(args0_real, -1) - args0_imag = torch.ops.aten.squeeze.dim(args0_imag, -1) - args1_real = torch.ops.aten.squeeze.dim(args1_real, -1) - args1_imag = torch.ops.aten.squeeze.dim(args1_imag, -1) - - complex_mul_real = torch.ops.aten.sub( - torch.ops.aten.mul(args0_real, args1_real), - torch.ops.aten.mul(args0_imag, args1_imag), - ) - complex_mul_imag = torch.ops.aten.add( - torch.ops.aten.mul(args0_real, args1_imag), - torch.ops.aten.mul(args0_imag, args1_real), - ) - - return torch.ops.aten.stack((complex_mul_real, complex_mul_imag), -1) - - -def remove_complex_real_view_nodes(gm: torch.fx.GraphModule): - modified_graph = False - nodes_to_remove = [] - for node in gm.graph.nodes: - if "view_as_complex" in node.name or "view_as_real" in node.name: - nodes_to_remove.append(node) - - for node in nodes_to_remove: - input_node = node.args[0] if node.args else None - - for other_node in gm.graph.nodes: - new_args = tuple( - input_node if arg is node else arg for arg in other_node.args - ) - other_node.args = new_args - - gm.graph.erase_node(node) - modified_graph = True - - if modified_graph: - gm = clean_up_graph_after_modifications(gm) - logger.debug( - f"Graph after removing view_as_complex nodes and view_as_real nodes:\n{gm.graph}" - ) - - -def modify_reshape_nodes(gm: torch.fx.GraphModule, complex_nodes): - for node in gm.graph.nodes: - if node in complex_nodes: - # slice and transpose will remain same - if "reshape" in node.name: - new_shape = list(node.args[1]) + [2] - node.args = (node.args[0], tuple(new_shape)) - - -def modify_mul_nodes(gm: torch.fx.GraphModule, complex_nodes): - modified_graph = False - for node in gm.graph.nodes: - if node in complex_nodes: - if "mul" in node.name: - complex_mul_args = (node.args[0], node.args[1]) - with gm.graph.inserting_after(node): - replacement_node = gm.graph.create_node( - op="call_function", - target=tensorrt_complex_mul, - args=complex_mul_args, - ) - node.replace_all_uses_with(replacement_node) - replacement_node.meta.update(node.meta) - modified_graph = True - gm.graph.erase_node(node) - - if modified_graph: - gm = clean_up_graph_after_modifications(gm) - logger.debug( - f"Graph after custom complex mul nodes is applied to the graph:\n{gm.graph}" - ) - - -def modify_complex_nodes(gm: torch.fx.GraphModule, complex_nodes): - modify_reshape_nodes(gm, complex_nodes) - remove_complex_real_view_nodes(gm) - modify_mul_nodes(gm, complex_nodes) - - -def 
modify_reshape_complex_nodes(gm: torch.fx.GraphModule, complexInputIndices): - complex_nodes = find_complex_nodes(gm) - if complex_nodes: - replace_complex_placeholder_to_tuple(gm, complexInputIndices) - modify_complex_nodes(gm, complex_nodes) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_replace_complex_placeholder_to_tuple.py b/py/torch_tensorrt/dynamo/lowering/passes/_replace_complex_placeholder_to_tuple.py deleted file mode 100644 index e2edec3d28..0000000000 --- a/py/torch_tensorrt/dynamo/lowering/passes/_replace_complex_placeholder_to_tuple.py +++ /dev/null @@ -1,112 +0,0 @@ -import logging -from typing import List, Tuple - -import torch -from torch._subclasses.fake_tensor import FakeTensorMode -from torch.fx.node import _get_qualified_name -from torch_tensorrt.dynamo.conversion.converter_utils import args_bounds_check - -# dead-code elimination, linting, and recompilation for graph, in-place -from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( - clean_up_graph_after_modifications, -) - -logger = logging.getLogger(__name__) - - -def replace_complex_placeholder_to_tuple( - gm: torch.fx.GraphModule, - inputListindices: List[int], -) -> torch.fx.GraphModule: - modified_graph = False - input_arg_list = [f"arg{inputListIndex}_1" for inputListIndex in inputListindices] - for node in gm.graph.nodes: - if node.op == "placeholder" and node.target in input_arg_list: - from torch._subclasses.fake_tensor import FakeTensorMode - - node_shape = node.meta["val"].size() - new_node_shape = node_shape + (2,) - new_node_dtype = None - if node.meta["val"].dtype == torch.complex64: - new_node_dtype = torch.float32 - else: - new_node_dtype = torch.float64 - fake_mode = FakeTensorMode() - - real_tensor = torch.empty(new_node_shape, dtype=new_node_dtype) - with FakeTensorMode() as fake_mode: - new_placeholder_tuple = fake_mode.from_tensor(real_tensor) - node.meta["val"] = new_placeholder_tuple - modified_graph = True - # propagate the meta data change for the downstream ops - # TODO:to check if this is required in all cases - propogate_complex_num_shape_change_till_complex_mul(gm, node, fake_mode) - - # If graph was modified, clean it up - if modified_graph: - gm = clean_up_graph_after_modifications(gm) - logger.debug( - f"Graph after fusing wait_tensor and distributed op tensor:\n{gm.graph}" - ) - - return gm - - -def infer_slice_shape(node: torch.fx.Node) -> Tuple[int, ...]: - input_shape = node.args[0].meta["val"].shape - slice_args = node.args - dim = slice_args[1] - start = slice_args[2] - end = slice_args[3] - step = args_bounds_check(slice_args, 4, replacement=1) - new_shape = list(input_shape) - new_shape[dim] = (end - start + step - 1) // step - return tuple(new_shape) - - -def infer_reshape_shape(node: torch.fx.Node) -> torch.fx.node.Argument: - return node.args[1] - - -shape_inference_funcs = { - "torch.ops.aten.slice.Tensor": infer_slice_shape, - "torch.ops.aten.reshape.default": infer_reshape_shape, -} - - -# Please note this function is for the use case of Llama model -# with complex placeholder->reshape->slice->complex mul -# Hence mul is the terminating op -def propogate_complex_num_shape_change_till_complex_mul( - node: torch.fx.Node, start_node: torch.fx.Node, fake_mode: FakeTensorMode -) -> None: - visited_nodes = set() - stack = [start_node] - while stack: - node = stack.pop() - if node in visited_nodes: - continue - visited_nodes.add(node) - update_node_meta(node, fake_mode) - for user in node.users: - if ( - user.op == "call_function" - and 
_get_qualified_name(user.target) == "torch.ops.aten.mul.Tensor" - ): - continue - stack.append(user) - - -def update_node_meta(node: torch.fx.Node, fake_mode: FakeTensorMode) -> None: - op_name = node.name - op_target = node.target - - if node.op == "call_function": - op_target = _get_qualified_name(node.target) - - if op_target in shape_inference_funcs: - new_shape = shape_inference_funcs[op_target](node) - real_tensor = torch.empty(new_shape, dtype=node.meta["val"].dtype) - node.meta["val"] = fake_mode.from_tensor(real_tensor) - else: - print("No shape for the inference function", {op_name}) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py b/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py index c3ead218aa..23cedc2211 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py @@ -108,6 +108,7 @@ class ComplexGraphRewriter: def __init__(self, gm: GraphModule, truncate_double: bool = False) -> None: self.gm = gm self.truncate_double = truncate_double + self.processed_input_nodes = set() def extract_shape_dtype_device( self, input_node: Node @@ -185,8 +186,12 @@ def rewrite_subgraph_nodes(self, subgraphs: List[ComplexSubGraphInfo]) -> None: for subgraph in subgraphs: for input_node in subgraph.input_nodes: logger.debug(f"Input node rewrite: {input_node.name}") + if input_node in self.processed_input_nodes: + logger.debug(f"Skipping {input_node.name}, already processed.") + continue if input_node.op not in ("call_function"): self.replace_input_node(input_node) + self.processed_input_nodes.add(input_node) for node in subgraph.subgraph_nodes: logger.debug(f"Subgraph Node rewrite: {node.name}") if node.target == torch.ops.aten.view_as_complex.default: @@ -230,6 +235,17 @@ def match_complex_mul( # type: ignore[no-untyped-def] elif node.target == torch.ops.aten.view_as_real.default: node.replace_all_uses_with(node.args[0]) self.gm.graph.erase_node(node) + elif node.target == torch.ops.aten._reshape_copy.default: + old_shape = node.args[1] + if isinstance(old_shape, (list, tuple)) and all( + isinstance(x, int) for x in old_shape + ): + new_shape = list(old_shape) + [2] + node.args = (node.args[0], new_shape) + logger.debug( + f"Updated reshape {node.name} from {old_shape} to {new_shape}" + ) + modified = True else: logger.debug(f"Unsupported node target: {node.target}") logger.debug( diff --git a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py index 9e54fbac3d..9290bf7909 100644 --- a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py @@ -8,6 +8,7 @@ from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten from torch_tensorrt.dynamo import partitioning +from torch_tensorrt.dynamo.runtime.utils import complex_to_ri_stacked_tensor logger = logging.getLogger(__name__) @@ -142,6 +143,9 @@ def forward( for i, _ in enumerate(inputs): if not contiguous_inputs[i].is_cuda: + contiguous_inputs[i] = complex_to_ri_stacked_tensor( + contiguous_inputs[i] + ) logger.warning( f"Detected input[{i}] is not on a cuda device. 
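The `_reshape_copy` branch added above appends a trailing 2 to static target shapes because, after the rewrite, a complex tensor travels as a float tensor whose last axis holds (real, imag). A small sketch of that shape bookkeeping, independent of the lowering pass:

```python
import torch

x = torch.randn(2, 8, dtype=torch.complex64)
x_ri = torch.view_as_real(x)  # (2, 8, 2) float32: last axis is (real, imag)

old_shape = [2, 4, 2]          # reshape written against the complex tensor
new_shape = old_shape + [2]    # what the rewriter turns it into
assert tuple(x.reshape(old_shape).shape) == (2, 4, 2)
assert tuple(x_ri.reshape(new_shape).shape) == (2, 4, 2, 2)
```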
" "This tensor is being moved by the runtime but for performance considerations, " diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index d18a5674e0..cd74bddd46 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -14,6 +14,7 @@ from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger +from torch_tensorrt.dynamo.runtime.utils import complex_to_ri_stacked_tensor from torch_tensorrt.dynamo.utils import DYNAMIC_DIM from torch_tensorrt.logging import TRT_LOGGER from torch_tensorrt.runtime._utils import ( @@ -358,6 +359,7 @@ def setup_input_tensors( need_cudagraphs_record: bool, ) -> None: for i, input_name in enumerate(self.input_names): + contiguous_inputs[i] = complex_to_ri_stacked_tensor(contiguous_inputs[i]) if not contiguous_inputs[i].is_cuda: logger.warning( f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. " diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 95f1581881..f238532a51 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -15,6 +15,7 @@ needs_torch_tensorrt_runtime, ) from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.runtime.utils import complex_to_ri_stacked_tensor logger = logging.getLogger(__name__) @@ -320,6 +321,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: # directly cast the input to a Torch Tensor. # # This also avoids the need for type-checking inputs, since they are now explicitly casted to Torch tensors + inputs = tuple(complex_to_ri_stacked_tensor(i) for i in inputs) input_tensors: List[torch.Tensor] = [ (i if isinstance(i, torch.Tensor) else torch.tensor(i).cuda()) for i in inputs diff --git a/py/torch_tensorrt/dynamo/runtime/utils.py b/py/torch_tensorrt/dynamo/runtime/utils.py new file mode 100644 index 0000000000..ad391b66b1 --- /dev/null +++ b/py/torch_tensorrt/dynamo/runtime/utils.py @@ -0,0 +1,8 @@ +import torch + + +def complex_to_ri_stacked_tensor(t: torch.Tensor) -> torch.Tensor: + # Converts complex tensor to real/imag stack + if torch.is_complex(t): + return torch.stack([t.real, t.imag], dim=-1) + return t diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 564250e5ae..6bc656732c 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -5,7 +5,16 @@ import warnings from dataclasses import fields, replace from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, + Union, +) import numpy as np import sympy @@ -34,6 +43,7 @@ RTOL = 5e-3 ATOL = 5e-3 CPU_DEVICE = "cpu" +_WHL_CPYTHON_VERSION = "cp310" class Frameworks(Enum): @@ -90,11 +100,9 @@ def unified_dtype_converter( ) -> Union[np.dtype, torch.dtype, TRTDataType]: """ Convert TensorRT, Numpy, or Torch data types to any other of those data types. - Args: dtype (TRTDataType, torch.dtype, np.dtype): A TensorRT, Numpy, or Torch data type. 
to (Frameworks): The framework to convert the data type to. - Returns: The equivalent data type in the requested framework. """ diff --git a/setup.py b/setup.py index 291cfe9b97..e29689de21 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ __cuda_version__: str = "0.0" __tensorrt_version__: str = "0.0" __tensorrt_rtx_version__: str = "0.0" +__tensorrt_llm_version__: str = "0.0" LEGACY_BASE_VERSION_SUFFIX_PATTERN = re.compile("a0$") # CI_PIPELINE_ID is the environment variable set by DLFW ci build @@ -69,6 +70,7 @@ def load_dep_info(): global __cuda_version__ global __tensorrt_version__ global __tensorrt_rtx_version__ + global __tensorrt_llm_version__ with open("dev_dep_versions.yml", "r") as stream: versions = yaml.safe_load(stream) if (gpu_arch_version := os.environ.get("CU_VERSION")) is not None: @@ -79,6 +81,7 @@ def load_dep_info(): __cuda_version__ = versions["__cuda_version__"] __tensorrt_version__ = versions["__tensorrt_version__"] __tensorrt_rtx_version__ = versions["__tensorrt_rtx_version__"] + __tensorrt_llm_version__ = versions["__tensorrt_llm_version__"] load_dep_info() @@ -249,6 +252,7 @@ def gen_version_file(): f.write('__cuda_version__ = "' + __cuda_version__ + '"\n') f.write('__tensorrt_version__ = "' + __tensorrt_version__ + '"\n') f.write('__tensorrt_rtx_version__ = "' + __tensorrt_rtx_version__ + '"\n') + f.write('__tensorrt_llm_version__ = "' + __tensorrt_llm_version__ + '"\n') def copy_libtorchtrt(multilinux=False, rt_only=False): @@ -450,6 +454,7 @@ def run(self): "torch_tensorrt.dynamo.conversion.impl.unary", "torch_tensorrt.dynamo.conversion.plugins", "torch_tensorrt.dynamo.debug", + "torch_tensorrt.dynamo.distributed", "torch_tensorrt.dynamo.lowering", "torch_tensorrt.dynamo.lowering.passes", "torch_tensorrt.dynamo.partitioning", diff --git a/tests/py/dynamo/distributed/distributed_utils.py b/tests/py/dynamo/distributed/distributed_utils.py index e3062249fa..b13a07d308 100644 --- a/tests/py/dynamo/distributed/distributed_utils.py +++ b/tests/py/dynamo/distributed/distributed_utils.py @@ -1,5 +1,6 @@ import logging import os +import random import numpy as np import tensorrt as trt @@ -8,25 +9,19 @@ from torch.distributed._tensor.device_mesh import init_device_mesh -def set_environment_variables_pytest(): +def set_environment_variables_pytest_single_process(): + port = 29500 + random.randint(1, 1000) os.environ["WORLD_SIZE"] = str(1) os.environ["RANK"] = str(0) os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = str(29500) - os.environ["USE_TRTLLM_PLUGINS"] = "1" - - -def initialize_logger(rank, logger_file_name): - logger = logging.getLogger() - logger.setLevel(logging.INFO) - fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w") - fh.setLevel(logging.INFO) - logger.addHandler(fh) - return logger + os.environ["MASTER_PORT"] = str(port) -# This is required for env initialization since we use mpirun -def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500): +def set_environment_variables_pytest_multi_process( + rank: int = 0, world_size: int = 1 +) -> None: + port = 29500 + random.randint(1, 1000) + # these variables are set by mpirun -n 2 local_rank = int( os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count()) ) @@ -37,7 +32,6 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950 os.environ["WORLD_SIZE"] = str(world_size) os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = str(port) - os.environ["TRTLLM_PLUGINS_PATH"] = 
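For reference, a hypothetical excerpt of the version module that the updated `gen_version_file()` would emit once `dev_dep_versions.yml` pins TensorRT-LLM; the value shown matches the wheel version removed from `test_nccl_ops.sh` further below and is illustrative only.

```python
# Hypothetical excerpt of the generated version module (illustrative value)
__tensorrt_llm_version__ = "0.17.0.post1"
```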
"./tmp/lib/libnvinfer_plugin_tensorrt_llm.so" # Necessary to assign a device to each rank. torch.cuda.set_device(local_rank) @@ -47,14 +41,3 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950 # set a manual seed for reproducibility torch.manual_seed(1111) - - device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,)) - rank = device_mesh.get_rank() - assert rank == local_rank - logger = initialize_logger(rank, logger_file_name) - device_id = ( - rank % torch.cuda.device_count() - ) # Ensure each rank gets a unique device - torch.cuda.set_device(device_id) - - return device_mesh, world_size, rank, logger diff --git a/tests/py/dynamo/distributed/test_nccl_ops.py b/tests/py/dynamo/distributed/test_nccl_ops.py index 89c94300b7..79c11bdeab 100644 --- a/tests/py/dynamo/distributed/test_nccl_ops.py +++ b/tests/py/dynamo/distributed/test_nccl_ops.py @@ -1,42 +1,87 @@ import os +import unittest import torch import torch.distributed as dist import torch.nn as nn -from distributed_utils import set_environment_variables_pytest +from conversion.harness import DispatchTestCase + +# The distributed env initialization has to be before torchTRT import since it uses barrier +from distributed_utils import ( + set_environment_variables_pytest_multi_process, + set_environment_variables_pytest_single_process, +) from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -set_environment_variables_pytest() -dist.init_process_group(backend="nccl", init_method="env://") -group = dist.new_group(ranks=[0]) -group_name = group.group_name -world_size = 1 +if "OMPI_COMM_WORLD_SIZE" in os.environ: + set_environment_variables_pytest_multi_process() +else: + set_environment_variables_pytest_single_process() -from conversion.harness import DispatchTestCase +if not dist.is_initialized(): + dist.init_process_group( + backend="nccl", + init_method="env://", + ) +from torch_tensorrt.dynamo.distributed.utils import is_platform_supported_for_trtllm -class TestGatherNcclOpsConverter(DispatchTestCase): - @parameterized.expand([8]) - def test_nccl_ops(self, linear_layer_dim): - class DistributedGatherModel(nn.Module): - def __init__(self, input_dim): - super().__init__() - self.fc = torch.nn.Linear(input_dim, input_dim) - - def forward(self, x): - x = self.fc(x) - gathered_tensor = torch.ops._c10d_functional.all_gather_into_tensor( - x, world_size, group_name - ) - gathered_tensor = torch.ops._c10d_functional.wait_tensor( - gathered_tensor - ) - return gathered_tensor +class DistributedGatherModel(nn.Module): + def __init__(self, input_dim, world_size, group_name): + super().__init__() + self.fc = nn.Linear(input_dim, input_dim) + self.world_size = world_size + self.group_name = group_name + + def forward(self, x): + x = self.fc(x) + gathered_tensor = torch.ops._c10d_functional.all_gather_into_tensor( + x, self.world_size, self.group_name + ) + return torch.ops._c10d_functional.wait_tensor(gathered_tensor) + + +class DistributedReduceScatterModel(nn.Module): + def __init__(self, input_dim, world_size, group_name): + super().__init__() + self.fc = nn.Linear(input_dim, input_dim) + self.world_size = world_size + self.group_name = group_name + + def forward(self, x): + x = self.fc(x) + out = torch.ops._c10d_functional.reduce_scatter_tensor( + x, "sum", self.world_size, self.group_name + ) + return torch.ops._c10d_functional.wait_tensor(out) + + +class TestNcclOpsConverter(DispatchTestCase): + @unittest.skipIf( + not 
is_platform_supported_for_trtllm(), + "Skipped on Windows, Jetson and CUDA13: NCCL backend is not supported.", + ) + @classmethod + def setUpClass(cls): + cls.world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1)) + cls.rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0)) + cls.group = dist.new_group(ranks=list(range(cls.world_size))) + cls.group_name = cls.group.group_name + + @classmethod + def tearDownClass(cls): + if dist.is_initialized(): + dist.destroy_process_group() + + @parameterized.expand([8]) + def test_nccl_ops_gather(self, linear_layer_dim): inputs = [torch.randn(1, linear_layer_dim).to("cuda")] self.run_test( - DistributedGatherModel(linear_layer_dim).cuda(), + DistributedGatherModel( + linear_layer_dim, self.world_size, self.group_name + ).cuda(), inputs, use_dynamo_tracer=True, enable_passes=True, @@ -44,28 +89,11 @@ def forward(self, x): @parameterized.expand([8]) def test_nccl_ops_scatter(self, linear_layer_dim): - - class DistributedReduceScatterModel(nn.Module): - def __init__(self, input_dim): - super().__init__() - self.fc = torch.nn.Linear(input_dim, input_dim) - - def forward(self, x): - x = self.fc(x) - scatter_reduce_tensor = ( - torch.ops._c10d_functional.reduce_scatter_tensor( - x, "sum", world_size, group_name - ) - ) - scatter_reduce_tensor = torch.ops._c10d_functional.wait_tensor( - scatter_reduce_tensor - ) - return scatter_reduce_tensor - inputs = [torch.zeros(1, linear_layer_dim).to("cuda")] - self.run_test( - DistributedReduceScatterModel(linear_layer_dim).cuda(), + DistributedReduceScatterModel( + linear_layer_dim, self.world_size, self.group_name + ).cuda(), inputs, use_dynamo_tracer=True, enable_passes=True, diff --git a/tests/py/dynamo/distributed/test_nccl_ops.sh b/tests/py/dynamo/distributed/test_nccl_ops.sh index dd54700048..677d0cb9bc 100644 --- a/tests/py/dynamo/distributed/test_nccl_ops.sh +++ b/tests/py/dynamo/distributed/test_nccl_ops.sh @@ -70,51 +70,6 @@ ensure_pytest_installed(){ echo "Setting up the environment" -OS="$(uname -s)" -ARCH="$(uname -m)" - - -#getting the file name for TensorRT-LLM download -if [[ "$OS" == "Linux" && "$ARCH" == "x86_64"]]; then - FILE="tensorrt_llm-0.17.0.post1-cp312-cp312-linux_x86_64.whl" -elif [[ "$OS" == "Linux" && "$ARCH" == "aarch64"]]; then - FILE="tensorrt_llm-0.17.0.post1-cp312-cp312-linux_aarch64.whl" -else: - echo "Unsupported platform: OS=$OS ARCH=$ARCH - exit 1 -fi - -# Download the selected file -URL="https://pypi.nvidia.com/tensorrt-llm/$FILE" -echo "Downloading $FILE from $URL..." - -#Installing wget -ensure_installed wget - -#Downloading the file -filename=$(basename "$URL") -if [ -f "$filename" ]; then - echo "File already exists: $filename" -else - wget "$URL" -fi -echo "Download complete: $FILE" - -UNZIP_DIR="tensorrt_llm_unzip" -if [[ ! -d "$UNZIP_DIR" ]]; then - echo "Creating directory: $UNZIP_DIR" - mkdir -p "$UNZIP_DIR" - echo "extracting $FILE to $UNZIP_DIR ..." - #Installing unzip - ensure_installed unzip - #unzip the TensorRT-LLM package - unzip -q "$FILE" -d "$UNZIP_DIR" - echo "Unzip complete" -fi - - -export TRTLLM_PLUGINS_PATH="$(pwd)/${UNZIP_DIR}/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so" -echo ${TRTLLM_PLUGINS_PATH} ensure_mpi_installed libmpich-dev ensure_mpi_installed libopenmpi-dev @@ -123,7 +78,7 @@ run_tests() { cd .. export PYTHONPATH=$(pwd) echo "Running pytest on distributed/test_nccl_ops.py..." - pytest distributed/test_nccl_ops.py + USE_TRTLLM_PLUGINS=1 pytest distributed/test_nccl_ops.py } run_mpi_tests(){