Commit 89dd3cb (1 parent: c4e82e9)

Update

[ghstack-poisoned]

File tree: 1 file changed (+3 −2 lines)

scripts/estimate/estimation.py (3 additions, 2 deletions)
@@ -15,6 +15,7 @@
 from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
 from torch.testing._internal.distributed.fake_pg import FakeStore
 
+from torchtitan.components.ft import init_ft_manager
 from torchtitan.components.optimizer import build_lr_schedulers, build_optimizers
 from torchtitan.config_manager import JobConfig
 from torchtitan.distributed import ParallelDims, utils as dist_utils
@@ -102,7 +103,6 @@ def estimate_memory(job_config: JobConfig):
         if not job_config.memory_estimation.disable_fake_mode
         else contextlib.nullcontext()
     ):
-
         logger.info(
             f"Building {train_spec.name} {job_config.model.flavor} with {model_config}"
         )
@@ -122,7 +122,8 @@ def estimate_memory(job_config: JobConfig):
         model.train()
 
         # build optimizer after applying parallelisms to the model
-        optimizers = build_optimizers([model], job_config)
+        ft_manager = init_ft_manager(job_config)
+        optimizers = build_optimizers([model], job_config, ft_manager)
         lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)
         # Post optimizer step model converters hook.
         # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
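In short, the memory-estimation script now creates a fault-tolerance manager from the job config and passes it as a new third argument to build_optimizers, matching the call shown in the last hunk. Below is a minimal sketch of that call sequence, assuming only the imports and signatures visible in this diff; build_training_components is a hypothetical helper used purely for illustration, not a function in estimation.py.

from torchtitan.components.ft import init_ft_manager
from torchtitan.components.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.config_manager import JobConfig


def build_training_components(model, job_config: JobConfig):
    # Hypothetical helper isolating the lines this commit touches.
    # The FT manager is built from the job config and threaded into
    # build_optimizers, whose call site now takes it as a third argument.
    ft_manager = init_ft_manager(job_config)
    optimizers = build_optimizers([model], job_config, ft_manager)
    # LR schedulers are built from the resulting optimizers; this call is
    # unchanged by the commit.
    lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)
    return ft_manager, optimizers, lr_schedulers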
