Commit bb6ec7a

Llama torchTRT lib and env initialization reorg
1 parent f00d349 commit bb6ec7a

2 files changed: +14 −10 lines

examples/distributed_inference/tensor_parallel_llama3.py

Lines changed: 13 additions & 7 deletions
@@ -1,30 +1,36 @@
 # Taken and modified pytorch lightening
 # https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
+# Taken and modified pytorch lightening
+# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
 import logging
 import os
 import time

 import torch
-import torch_tensorrt
+import torch.distributed as dist
 from llama3_model import ModelArgs, ParallelTransformer
+from tensor_parallel_initialize_dist import (
+    cleanup_distributed_env,
+    initialize_distributed_env,
+)
 from torch.distributed._composable.fsdp import MixedPrecisionPolicy
 from torch.distributed._composable.fsdp.fully_shard import fully_shard
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper,
 )
+
+if not dist.is_initialized():
+    initialize_distributed_env()
+
+import torch_tensorrt
 from torch_tensorrt.dynamo.distributed.utils import (
-    cleanup_distributed_env,
     get_tensor_parallel_device_mesh,
-    initialize_distributed_env,
     initialize_logger,
 )

-if not dist.is_initialized():
-    initialize_distributed_env()
-
 device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
-logger = initialize_logger(_rank, "tensor_parallel_simple_example")
+logger = initialize_logger(_rank, "tensor_parallel_llama3")

 logger.info(f"Starting PyTorch TP example on rank {_rank}.")
 assert (
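
The reorganization above is mainly about ordering: the distributed environment is now brought up via the example-local tensor_parallel_initialize_dist helpers before torch_tensorrt is imported, and initialize_distributed_env / cleanup_distributed_env are no longer pulled from torch_tensorrt.dynamo.distributed.utils. Below is a minimal sketch of that init-before-import pattern; the helper bodies are assumptions (the real helpers live in the example's tensor_parallel_initialize_dist.py and are not shown in this diff).

# Minimal sketch of the init-before-import ordering used above.
# The helper bodies here are assumed, not the actual contents of
# tensor_parallel_initialize_dist.py.
import os

import torch
import torch.distributed as dist


def initialize_distributed_env() -> None:
    # Assumed behavior: create the default process group from torchrun env vars
    # and pin this rank to its local GPU.
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))


def cleanup_distributed_env() -> None:
    # Assumed behavior: tear the process group down at the end of the example.
    if dist.is_initialized():
        dist.destroy_process_group()


if not dist.is_initialized():
    initialize_distributed_env()

# torch_tensorrt is imported only after the distributed environment exists,
# matching the reordering in this commit.
import torch_tensorrt  # noqa: E402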

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 1 addition & 3 deletions
@@ -359,10 +359,8 @@ def setup_input_tensors(
         need_cudagraphs_record: bool,
     ) -> None:
         for i, input_name in enumerate(self.input_names):
+            contiguous_inputs[i] = complex_to_ri_stacked_tensor(contiguous_inputs[i])
             if not contiguous_inputs[i].is_cuda:
-                contiguous_inputs[i] = complex_to_ri_stacked_tensor(
-                    contiguous_inputs[i]
-                )
                 logger.warning(
                     f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
                     "This tensor is being moved by the runtime but for performance considerations, "
