Skip to content

Commit 03702a4

Browse files
committed
add hta
Test Plan:
```
$ RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000
$ USE_NCCL=True LOG_LEVEL=DEBUG RUST_LOG=error USE_STREAMING=True torchx run ./torchft/torchx.py:hsdp --script='train_diloco.py'
$ python trace_analysis.py
```
1 parent 855bcad commit 03702a4

File tree

2 files changed

+32
-2
lines changed

2 files changed

+32
-2
lines changed

trace_analysis.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import os
8+
9+
from hta.trace_analysis import TraceAnalysis
10+
11+
# Directory holding the profiler traces to analyze. The path follows the
# "output/replica-<id>/profiles/step-<n>" layout that train_diloco.py's
# trace_handler writes, with one "rank-<RANK>.json" file per rank; here it is
# pinned to replica 0, step 120.
_PROFILES_DIR = "output/replica-0/profiles/step-120"
12+
13+
14+
def main():
    """Run HTA critical path analysis for rank 0 and overlay it on the trace.

    Builds a ``TraceAnalysis`` over ``_PROFILES_DIR``, asks it for the
    critical path graph of rank 0 (empty annotation, no instance id), and,
    when that succeeds, calls ``overlay_critical_path_analysis`` with
    ``_PROFILES_DIR`` as the output directory. On failure a message is
    printed and nothing else happens.
    """
    trace_analysis = TraceAnalysis(trace_dir=_PROFILES_DIR)

    # The call returns a (graph, success flag) pair.
    outcome = trace_analysis.critical_path_analysis(
        rank=0,
        annotation="",
        instance_id=None,
    )
    graph, ok = outcome

    if ok:
        # The overlay output goes to the same directory the traces were read from.
        trace_analysis.overlay_critical_path_analysis(
            0, graph, output_dir=_PROFILES_DIR
        )
    else:
        print("Critical path analysis failed")
23+
24+
25+
if __name__ == "__main__":
26+
main()

train_diloco.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
@record
4444
def main() -> None:
4545
REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
46+
RANK = int(os.environ.get("RANK", 0))
4647
RUN = int(os.environ.get("RUN", 0))
4748

4849
output_folder = f"output/replica-{REPLICA_GROUP_ID}"
@@ -177,11 +178,11 @@ def forward(self, x):
177178
print(f"Total number of parameters: {num_params}")
178179

179180
def trace_handler(p):
180-
dir = f"{output_folder}/profiles"
181+
dir = f"{output_folder}/profiles/step-{p.step_num}"
181182
if not os.path.exists(dir):
182183
os.makedirs(dir, exist_ok=True)
183184

184-
p.export_chrome_trace(f"{dir}/step-{p.step_num}.json")
185+
p.export_chrome_trace(f"{dir}/rank-{RANK}.json")
185186

186187
# You can use an epoch based training but with faults it's easier to use step
187188
# based training.
@@ -190,6 +191,9 @@ def trace_handler(p):
190191
on_trace_ready=trace_handler,
191192
record_shapes=False,
192193
profile_memory=False,
194+
experimental_config=torch.profiler._ExperimentalConfig( # type: ignore
195+
enable_cuda_sync_events=True
196+
),
193197
)
194198

195199
prof.start()

0 commit comments

Comments
 (0)