# Taken and modified from PyTorch Lightning
# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
import logging
import os
import time

import torch
-import torch_tensorrt
+import torch.distributed as dist
from llama3_model import ModelArgs, ParallelTransformer
+from tensor_parallel_initialize_dist import (
+    cleanup_distributed_env,
+    initialize_distributed_env,
+)
from torch.distributed._composable.fsdp import MixedPrecisionPolicy
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)
+
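+# Ensure the distributed environment is initialized before torch_tensorrt is imported.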
+if not dist.is_initialized():
+    initialize_distributed_env()
+
+import torch_tensorrt
from torch_tensorrt.dynamo.distributed.utils import (
-    cleanup_distributed_env,
    get_tensor_parallel_device_mesh,
-    initialize_distributed_env,
    initialize_logger,
)

-if not dist.is_initialized():
-    initialize_distributed_env()
-
device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
-logger = initialize_logger(_rank, "tensor_parallel_simple_example")
+logger = initialize_logger(_rank, "tensor_parallel_llama3")

logger.info(f"Starting PyTorch TP example on rank {_rank}.")
assert (