
Commit c4fb543

clean up managed work
1 parent 22b8fa1 commit c4fb543

File tree: 6 files changed, +351 −43 lines changed


torchft/_test/managed_work_test.py

Lines changed: 215 additions & 0 deletions
@@ -0,0 +1,215 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import types
import unittest
from datetime import timedelta
from typing import Callable, List, Optional

import parameterized
import torch
from torch.distributed.distributed_c10d import Work
from torch.futures import Future

from torchft.manager import Manager, _ManagedWork


class SimpleWork(Work):
    """A simple implementation of torch.distributed.Work for testing."""

    def __init__(self, tensors: List[torch.Tensor]) -> None:
        super().__init__()
        self._tensors = tensors
        self._future: Future[List[torch.Tensor]] = torch.futures.Future()
        self._is_completed: bool = False

    def wait(self, timeout: Optional[timedelta] = None) -> bool:
        self._is_completed = True
        self._future.set_result(self._tensors)
        return True

    def get_future(self) -> Future[List[torch.Tensor]]:
        return self._future


class TestManagedWork(unittest.TestCase):
    @parameterized.parameterized.expand(
        [
            ("cpu", torch.device("cpu")),
            ("cuda", torch.device("cuda:0")),
        ]
    )
    def test_callbacks_execute_after_wait(
        self, name: str, device: torch.device
    ) -> None:
        """Test that callbacks are only executed after wait() is called."""
        # Skip if CUDA is requested but not available
        if device.type == "cuda" and not torch.cuda.is_available():
            self.skipTest("CUDA not available")

        # Create a tensor to work with
        tensor = torch.ones(1, dtype=torch.float32, device=device)

        # Create a simple work object
        work = SimpleWork([tensor])

        # Create a minimal manager object with just the wrap_future method
        manager = Manager.__new__(Manager)  # Create instance without calling __init__
        # We're using types.MethodType to attach a method to the manager instance
        # This is just for testing purposes
        manager.wrap_future = types.MethodType(  # type: ignore
            lambda self, fut, default, timeout=None: fut, manager
        )

        # Create the managed work
        managed_work = _ManagedWork(work, manager, [tensor])

        # Track callback execution
        callback_executed: bool = False

        def callback(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]:
            nonlocal callback_executed
            callback_executed = True
            # Multiply tensor by 2 to verify the callback ran
            fut.value()[0].mul_(2)
            return fut.value()

        # Add the callback
        managed_work.add_callback(callback)

        # Verify callback hasn't executed yet
        self.assertFalse(callback_executed)
        self.assertEqual(tensor.item(), 1.0)

        # Call wait() which should trigger the callback
        managed_work.wait()

        # Verify callback has executed
        self.assertTrue(callback_executed)
        self.assertEqual(tensor.item(), 2.0)

    @parameterized.parameterized.expand(
        [
            ("cpu", torch.device("cpu")),
            ("cuda", torch.device("cuda:0")),
        ]
    )
    def test_multiple_callbacks_execute_in_order(
        self, name: str, device: torch.device
    ) -> None:
        """Test that multiple callbacks are executed in the order they were added."""
        # Skip if CUDA is requested but not available
        if device.type == "cuda" and not torch.cuda.is_available():
            self.skipTest("CUDA not available")

        # Create a tensor to work with
        tensor = torch.ones(1, dtype=torch.float32, device=device)

        # Create a simple work object
        work = SimpleWork([tensor])

        # Create a minimal manager object with just the wrap_future method
        manager = Manager.__new__(Manager)  # Create instance without calling __init__
        manager.wrap_future = types.MethodType(  # type: ignore
            lambda self, fut, default, timeout=None: fut, manager
        )

        # Create the managed work
        managed_work = _ManagedWork(work, manager, [tensor])

        # Track execution order
        execution_order: List[int] = []

        def callback1(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]:
            execution_order.append(1)
            fut.value()[0].add_(1)
            return fut.value()

        def callback2(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]:
            execution_order.append(2)
            fut.value()[0].add_(2)
            return fut.value()

        def callback3(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]:
            execution_order.append(3)
            fut.value()[0].add_(3)
            return fut.value()

        # Add callbacks
        managed_work.add_callback(callback1)
        managed_work.add_callback(callback2)
        managed_work.add_callback(callback3)

        # Verify no callbacks have executed yet
        self.assertEqual(len(execution_order), 0)
        self.assertEqual(tensor.item(), 1.0)

        # Call wait() which should trigger the callbacks
        managed_work.wait()

        # Verify callbacks executed in order
        self.assertEqual(execution_order, [1, 2, 3])

        # Each callback adds to the tensor, so final value should be 1 + 1 + 2 + 3 = 7
        self.assertEqual(tensor.item(), 7.0)

    @parameterized.parameterized.expand(
        [
            ("cpu", torch.device("cpu")),
            ("cuda", torch.device("cuda:0")),
        ]
    )
    def test_future_then_api(self, name: str, device: torch.device) -> None:
        """Test that the future's then API works correctly with ManagedWork."""
        # Skip if CUDA is requested but not available
        if device.type == "cuda" and not torch.cuda.is_available():
            self.skipTest("CUDA not available")

        # Create a tensor to work with
        tensor = torch.ones(1, dtype=torch.float32, device=device)

        # Create a simple work object
        work = SimpleWork([tensor])

        # Create a minimal manager object with just the wrap_future method
        manager = Manager.__new__(Manager)  # Create instance without calling __init__
        manager.wrap_future = types.MethodType(  # type: ignore
            lambda self, fut, default, timeout=None: fut, manager
        )

        # Create the managed work
        managed_work = _ManagedWork(work, manager, [tensor])

        # Get the future
        future = managed_work.get_future()

        # Track callback execution
        callback_executed: bool = False

        def callback(fut: Future[List[torch.Tensor]]) -> List[torch.Tensor]:
            nonlocal callback_executed
            callback_executed = True
            # Multiply tensor by 3 to verify the callback ran
            fut.value()[0].mul_(3)
            return fut.value()

        # Use the then API
        future.then(callback)

        # Verify callback hasn't executed yet
        self.assertFalse(callback_executed)
        self.assertEqual(tensor.item(), 1.0)

        # Call wait() which should trigger the callback
        future.wait()

        # Verify callback has executed
        self.assertTrue(callback_executed)
        self.assertEqual(tensor.item(), 3.0)


if __name__ == "__main__":
    unittest.main()
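Since the file ends in a unittest.main() guard, the new tests can be run directly (for example, python torchft/_test/managed_work_test.py) or through the project's usual test runner; the CUDA-parameterized cases skip themselves when no GPU is available.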

torchft/ddp.py

Lines changed: 16 additions & 2 deletions
@@ -69,8 +69,22 @@ def _comm_hook(
         state: "Manager", bucket: dist.GradBucket
     ) -> torch.futures.Future[torch.Tensor]:
         work = state.allreduce(bucket.buffer())
-        work.synchronize()
-        return work.get_future()
+
+        result_fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
+
+        fut = work.get_future()
+
+        def callback(
+            tensors: torch.futures.Future[list[torch.Tensor]],
+        ) -> list[torch.Tensor]:
+            nonlocal result_fut
+            result_fut.set_result(tensors.value()[0])
+            return []
+
+        fut = fut.then(callback)
+
+        work.wait()
+        return result_fut
 
 
 class PureDistributedDataParallel(nn.Module):
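The updated comm hook no longer calls work.synchronize(); instead it bridges the collective's future into a fresh result future: a callback registered with then() unwraps the reduced tensor and fulfills the result future, which the hook returns. Below is a minimal, self-contained sketch of this bridge pattern using only torch.futures; the names src_fut and result_fut are illustrative and not part of torchft.

import torch

# Source future, standing in for what a collective's get_future() would return.
src_fut: torch.futures.Future[list[torch.Tensor]] = torch.futures.Future()

# Bridge future that downstream code (e.g. the comm hook's caller) waits on.
result_fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()


def callback(fut: torch.futures.Future[list[torch.Tensor]]) -> list[torch.Tensor]:
    # Unwrap the single reduced tensor and hand it to the bridge future.
    result_fut.set_result(fut.value()[0])
    return []


# then() registers the callback and returns a new chained future.
src_fut.then(callback)

# Completing the source future fires the callback, which completes result_fut.
src_fut.set_result([torch.ones(3)])
print(result_fut.wait())  # tensor([1., 1., 1.])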

torchft/local_sgd.py

Lines changed: 6 additions & 3 deletions
@@ -519,7 +519,9 @@ def _bucketize_and_allreduce(
                 flat_buffer, should_quantize=self.should_quantize
             )
 
-            def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
+            def callback(
+                fut: torch.futures.Future[list[torch.Tensor]],
+            ) -> list[torch.Tensor]:
                 with torch.cuda.stream(self._stream) if self._stream else nullcontext():
                     nonlocal bucket_tensors, flat_buffer
                     # Setup stream dependency
@@ -529,9 +531,10 @@ def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
                         flat_buffer[pack_offset : pack_offset + numel].view_as(t)
                     )
 
-                work.synchronize()
+                return []
+
             fut = work.get_future()
-            fut.add_done_callback(callback)
+            fut = fut.then(callback)
 
             self._allreduce_work.append(work)
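The local_sgd change swaps fut.add_done_callback(callback) for fut = fut.then(callback): add_done_callback runs the callback for its side effects only and discards its return value, while then() returns a new future that completes with the callback's result, so later waits observe the chained value. A small illustrative sketch of the difference, using plain torch.futures rather than torchft internals:

import torch

fut: torch.futures.Future[int] = torch.futures.Future()

# add_done_callback: runs for its side effects only; its return value is dropped.
fut.add_done_callback(lambda f: print("done callback saw", f.value()))

# then: yields a chained future that completes with the callback's return value.
chained = fut.then(lambda f: f.value() * 2)

fut.set_result(21)
print(chained.wait())  # 42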
