Skip to content

Commit fac5d51

Browse files
prajjwal1facebook-github-bot
authored andcommitted
fix RecMetrics loading (make trained_batches a buffer) (#3534)
Summary: This diff addresses the following task: T209753398 Currently `trained_batches` is not stored in state_dict, requiring us to manually sync this variable upon checkpoint loading. We make this variable a buffer so that it can now be captured with model state dict. Differential Revision: D86697665
1 parent 4cb39c1 commit fac5d51

File tree

3 files changed

+17
-2
lines changed

3 files changed

+17
-2
lines changed

torchrec/metrics/cpu_offloaded_metric_module.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,9 @@ def sync(self) -> None:
474474
)
475475
self.comms_module.load_pre_compute_states(aggregated_states)
476476

477+
# Sync _trained_batches to comms module
478+
self.comms_module._trained_batches.copy_(self._trained_batches)
479+
477480
logger.info("CPUOffloadedRecMetricModule synced.")
478481

479482
@override

torchrec/metrics/metric_module.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,9 @@ def __init__(
202202
self.rec_metrics = rec_metrics if rec_metrics else RecMetricList([])
203203
self.throughput_metric = throughput_metric
204204
self.state_metrics = state_metrics if state_metrics else {}
205-
self.trained_batches: int = 0
205+
206+
self.register_buffer("_trained_batches", torch.tensor(0), persistent=True)
207+
206208
self.batch_size = batch_size
207209
self.world_size = world_size
208210
self.oom_count = 0
@@ -228,6 +230,15 @@ def __init__(
228230
)
229231
self.last_compute_time = -1.0
230232

233+
@property
234+
def trained_batches(self) -> int:
235+
# .trained_batches should return an int
236+
return int(self._trained_batches.item())
237+
238+
@trained_batches.setter
239+
def trained_batches(self, value: int) -> None:
240+
self._trained_batches.fill_(int(value))
241+
231242
def _update_rec_metrics(
232243
self, model_out: Dict[str, torch.Tensor], **kwargs: Any
233244
) -> None:
@@ -260,7 +271,7 @@ def update(self, model_out: Dict[str, torch.Tensor], **kwargs: Any) -> None:
260271
self._update_rec_metrics(model_out, **kwargs)
261272
if self.throughput_metric:
262273
self.throughput_metric.update()
263-
self.trained_batches += 1
274+
self._trained_batches.add_(1)
264275

265276
def _adjust_compute_interval(self) -> None:
266277
"""

torchrec/metrics/tests/test_cpu_offloaded_metric_module.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,7 @@ def test_state_dict_save_load(self) -> None:
341341
"rec_metrics.rec_metrics.0._metrics_computations.0.state_3": torch.tensor(
342342
[6.0]
343343
),
344+
"_trained_batches": torch.tensor(0),
344345
},
345346
)
346347

0 commit comments

Comments
 (0)