Skip to content

Commit c952de9

Browse files
committed
add hta
1 parent 855bcad commit c952de9

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

trace_analysis.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import os
8+
from hta.trace_analysis import TraceAnalysis
9+
10+
# Directory handed to HTA both as the trace input directory and as the output
# directory for the critical-path overlay (see main()). Hard-coded to a single
# replica/step, following the `output/replica-{id}/profiles/step-{n}` layout
# that train_diloco.py's trace_handler writes per-rank traces into.
_PROFILES_DIR = "output/replica-0/profiles/step-120"
11+
12+
def main():
    """Run HTA critical-path analysis for rank 0 and overlay the result.

    Builds a ``TraceAnalysis`` over ``_PROFILES_DIR``, requests the critical
    path for rank 0 (empty annotation, no instance id), and, if HTA reports
    success, asks it to overlay that critical path using ``_PROFILES_DIR`` as
    the output directory. If the analysis reports failure, a message is
    printed and no overlay is attempted.
    """
    hta = TraceAnalysis(trace_dir=_PROFILES_DIR)
    graph, ok = hta.critical_path_analysis(
        rank=0,
        annotation="",
        instance_id=None,
    )

    if ok:
        hta.overlay_critical_path_analysis(0, graph, output_dir=_PROFILES_DIR)
    else:
        print("Critical path analysis failed")
20+
21+
if __name__ == "__main__":
22+
main()

train_diloco.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
@record
4444
def main() -> None:
4545
REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
46+
RANK = int(os.environ.get("RANK", 0))
4647
RUN = int(os.environ.get("RUN", 0))
4748

4849
output_folder = f"output/replica-{REPLICA_GROUP_ID}"
@@ -177,11 +178,11 @@ def forward(self, x):
177178
print(f"Total number of parameters: {num_params}")
178179

179180
def trace_handler(p):
    """Export the profiler's Chrome trace for the current step and rank.

    Traces are grouped into one directory per profiler step, with one file
    per rank: ``{output_folder}/profiles/step-{step_num}/rank-{RANK}.json``.

    Args:
        p: Profiler object passed to the ``on_trace_ready`` callback. It must
            provide ``step_num`` and ``export_chrome_trace(path)``.
    """
    # Named ``profile_dir`` rather than ``dir`` so the builtin is not shadowed.
    profile_dir = f"{output_folder}/profiles/step-{p.step_num}"
    # ``exist_ok=True`` already tolerates an existing directory, so a separate
    # ``os.path.exists`` check beforehand is unnecessary.
    os.makedirs(profile_dir, exist_ok=True)

    p.export_chrome_trace(f"{profile_dir}/rank-{RANK}.json")
185186

186187
# You can use an epoch based training but with faults it's easier to use step
187188
# based training.
@@ -190,6 +191,9 @@ def trace_handler(p):
190191
on_trace_ready=trace_handler,
191192
record_shapes=False,
192193
profile_memory=False,
194+
experimental_config=torch.profiler._ExperimentalConfig( # type: ignore
195+
enable_cuda_sync_events=True
196+
),
193197
)
194198

195199
prof.start()

0 commit comments

Comments
 (0)