Skip to content

Commit bb3fb15

Browse files
committed
[Refactor] Weight sync schemes refactor
ghstack-source-id: 902cff2 Pull-Request: #3230
1 parent 56990b9 commit bb3fb15

File tree

3 files changed

+937
-206
lines changed

3 files changed

+937
-206
lines changed

test/test_weightsync.py

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,8 @@ def worker_update_policy(pipe, timeout=5.0):
3232
policy.bias.fill_(0.0)
3333

3434
scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
35-
receiver = scheme.create_receiver()
36-
receiver.register_model(policy)
37-
receiver.register_worker_transport(pipe)
35+
scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy)
36+
receiver = scheme.get_receiver()
3837

3938
if receiver._transport.pipe.poll(timeout):
4039
data, msg = receiver._transport.pipe.recv()
@@ -52,9 +51,8 @@ def worker_update_policy_tensordict(pipe, timeout=5.0):
5251
policy.bias.fill_(0.0)
5352

5453
scheme = MultiProcessWeightSyncScheme(strategy="tensordict")
55-
receiver = scheme.create_receiver()
56-
receiver.register_model(policy)
57-
receiver.register_worker_transport(pipe)
54+
scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy)
55+
receiver = scheme.get_receiver()
5856

5957
if receiver._transport.pipe.poll(timeout):
6058
data, msg = receiver._transport.pipe.recv()
@@ -192,18 +190,23 @@ def test_cross_format_conversion(self):
192190

193191

194192
class TestWeightSyncSchemes:
193+
"""Tests for weight sync schemes using the new simplified API.
194+
195+
Lower-level transport and legacy API tests are in TestTransportBackends.
196+
"""
197+
195198
def test_multiprocess_scheme_state_dict(self):
196199
parent_pipe, child_pipe = mp.Pipe()
197200

198201
scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
199-
sender = scheme.create_sender()
200-
sender.register_worker(0, parent_pipe)
202+
scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])
203+
sender = scheme.get_sender()
201204

202205
proc = mp.Process(target=worker_update_policy, args=(child_pipe,))
203206
proc.start()
204207

205208
weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)}
206-
sender.update_weights(weights)
209+
sender.send(weights)
207210

208211
proc.join(timeout=10.0)
209212
assert not proc.is_alive()
@@ -212,16 +215,16 @@ def test_multiprocess_scheme_tensordict(self):
212215
parent_pipe, child_pipe = mp.Pipe()
213216

214217
scheme = MultiProcessWeightSyncScheme(strategy="tensordict")
215-
sender = scheme.create_sender()
216-
sender.register_worker(0, parent_pipe)
218+
scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])
219+
sender = scheme.get_sender()
217220

218221
proc = mp.Process(target=worker_update_policy_tensordict, args=(child_pipe,))
219222
proc.start()
220223

221224
weights = TensorDict(
222225
{"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[]
223226
)
224-
sender.update_weights(weights)
227+
sender.send(weights)
225228

226229
proc.join(timeout=10.0)
227230
assert not proc.is_alive()
@@ -270,6 +273,50 @@ def test_no_weight_sync_scheme(self):
270273
weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)}
271274
transport.send_weights("policy", weights)
272275

276+
def test_receiver_receive_method(self):
277+
"""Test the new non-blocking receive() method."""
278+
279+
def worker_with_receive(pipe):
280+
policy = nn.Linear(4, 2)
281+
with torch.no_grad():
282+
policy.weight.fill_(0.0)
283+
policy.bias.fill_(0.0)
284+
285+
scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
286+
scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy)
287+
receiver = scheme.get_receiver()
288+
289+
# Non-blocking receive should return False when no data
290+
result = receiver.receive(timeout=0.001)
291+
assert result is False
292+
293+
# Now actually receive the weights
294+
result = receiver.receive(timeout=5.0)
295+
assert result is True
296+
297+
# Check weights were applied
298+
return policy.weight.sum().item(), policy.bias.sum().item()
299+
300+
parent_pipe, child_pipe = mp.Pipe()
301+
302+
scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
303+
scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])
304+
sender = scheme.get_sender()
305+
306+
proc = mp.Process(target=worker_with_receive, args=(child_pipe,))
307+
proc.start()
308+
309+
# Give worker time to call receive with no data
310+
import time
311+
312+
time.sleep(0.1)
313+
314+
weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)}
315+
sender.send(weights)
316+
317+
proc.join(timeout=10.0)
318+
assert not proc.is_alive()
319+
273320

274321
class TestCollectorIntegration:
275322
@pytest.fixture

0 commit comments

Comments
 (0)