Skip to content

Commit 00a7b69

Browse files
prajjwal1facebook-github-bot
authored andcommitted
fix RecMetrics loading (make trained_batches a buffer) (#3534)
Summary: This diff addresses the following task: T209753398 Currently `trained_batches` is not stored in state_dict, requiring us to manually sync this variable upon checkpoint loading. We make this variable a buffer so that it can now be captured with model state dict. Differential Revision: D86697665
1 parent d6ee3e0 commit 00a7b69

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

torchrec/metrics/metric_module.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,10 @@ def __init__(
202202
self.rec_metrics = rec_metrics if rec_metrics else RecMetricList([])
203203
self.throughput_metric = throughput_metric
204204
self.state_metrics = state_metrics if state_metrics else {}
205-
self.trained_batches: int = 0
205+
206+
trained_batches = torch.tensor(0, dtype=torch.int64)
207+
self.register_buffer("_trained_batches", trained_batches, persistent=True)
208+
206209
self.batch_size = batch_size
207210
self.world_size = world_size
208211
self.oom_count = 0
@@ -228,6 +231,15 @@ def __init__(
228231
)
229232
self.last_compute_time = -1.0
230233

234+
@property
235+
def trained_batches(self) -> int:
236+
# .trained_batches should return an int
237+
return int(self._trained_batches.item())
238+
239+
@trained_batches.setter
240+
def trained_batches(self, value: int) -> None:
241+
self._trained_batches.fill_(int(value))
242+
231243
def _update_rec_metrics(
232244
self, model_out: Dict[str, torch.Tensor], **kwargs: Any
233245
) -> None:
@@ -260,7 +272,7 @@ def update(self, model_out: Dict[str, torch.Tensor], **kwargs: Any) -> None:
260272
self._update_rec_metrics(model_out, **kwargs)
261273
if self.throughput_metric:
262274
self.throughput_metric.update()
263-
self.trained_batches += 1
275+
self._trained_batches.add_(1)
264276

265277
def _adjust_compute_interval(self) -> None:
266278
"""

torchrec/metrics/tests/test_cpu_offloaded_metric_module.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,7 @@ def test_state_dict_save_load(self) -> None:
341341
"rec_metrics.rec_metrics.0._metrics_computations.0.state_3": torch.tensor(
342342
[6.0]
343343
),
344+
"_trained_batches": torch.tensor([0], dtype=torch.int64),
344345
},
345346
)
346347

0 commit comments

Comments
 (0)