@@ -32,9 +32,8 @@ def worker_update_policy(pipe, timeout=5.0):
3232 policy .bias .fill_ (0.0 )
3333
3434 scheme = MultiProcessWeightSyncScheme (strategy = "state_dict" )
35- receiver = scheme .create_receiver ()
36- receiver .register_model (policy )
37- receiver .register_worker_transport (pipe )
35+ scheme .init_on_worker (model_id = "policy" , pipe = pipe , model = policy )
36+ receiver = scheme .get_receiver ()
3837
3938 if receiver ._transport .pipe .poll (timeout ):
4039 data , msg = receiver ._transport .pipe .recv ()
@@ -52,9 +51,8 @@ def worker_update_policy_tensordict(pipe, timeout=5.0):
5251 policy .bias .fill_ (0.0 )
5352
5453 scheme = MultiProcessWeightSyncScheme (strategy = "tensordict" )
55- receiver = scheme .create_receiver ()
56- receiver .register_model (policy )
57- receiver .register_worker_transport (pipe )
54+ scheme .init_on_worker (model_id = "policy" , pipe = pipe , model = policy )
55+ receiver = scheme .get_receiver ()
5856
5957 if receiver ._transport .pipe .poll (timeout ):
6058 data , msg = receiver ._transport .pipe .recv ()
@@ -192,18 +190,23 @@ def test_cross_format_conversion(self):
192190
193191
194192class TestWeightSyncSchemes :
193+ """Tests for weight sync schemes using the new simplified API.
194+
195+ Lower-level transport and legacy API tests are in TestTransportBackends.
196+ """
197+
195198 def test_multiprocess_scheme_state_dict (self ):
196199 parent_pipe , child_pipe = mp .Pipe ()
197200
198201 scheme = MultiProcessWeightSyncScheme (strategy = "state_dict" )
199- sender = scheme .create_sender ( )
200- sender . register_worker ( 0 , parent_pipe )
202+ scheme .init_on_sender ( model_id = "policy" , pipes = [ parent_pipe ] )
203+ sender = scheme . get_sender ( )
201204
202205 proc = mp .Process (target = worker_update_policy , args = (child_pipe ,))
203206 proc .start ()
204207
205208 weights = {"weight" : torch .ones (2 , 4 ), "bias" : torch .ones (2 )}
206- sender .update_weights (weights )
209+ sender .send (weights )
207210
208211 proc .join (timeout = 10.0 )
209212 assert not proc .is_alive ()
@@ -212,16 +215,16 @@ def test_multiprocess_scheme_tensordict(self):
212215 parent_pipe , child_pipe = mp .Pipe ()
213216
214217 scheme = MultiProcessWeightSyncScheme (strategy = "tensordict" )
215- sender = scheme .create_sender ( )
216- sender . register_worker ( 0 , parent_pipe )
218+ scheme .init_on_sender ( model_id = "policy" , pipes = [ parent_pipe ] )
219+ sender = scheme . get_sender ( )
217220
218221 proc = mp .Process (target = worker_update_policy_tensordict , args = (child_pipe ,))
219222 proc .start ()
220223
221224 weights = TensorDict (
222225 {"weight" : torch .ones (2 , 4 ), "bias" : torch .ones (2 )}, batch_size = []
223226 )
224- sender .update_weights (weights )
227+ sender .send (weights )
225228
226229 proc .join (timeout = 10.0 )
227230 assert not proc .is_alive ()
@@ -270,6 +273,50 @@ def test_no_weight_sync_scheme(self):
270273 weights = {"weight" : torch .ones (2 , 4 ), "bias" : torch .ones (2 )}
271274 transport .send_weights ("policy" , weights )
272275
276+ def test_receiver_receive_method (self ):
277+ """Test the new non-blocking receive() method."""
278+
279+ def worker_with_receive (pipe ):
280+ policy = nn .Linear (4 , 2 )
281+ with torch .no_grad ():
282+ policy .weight .fill_ (0.0 )
283+ policy .bias .fill_ (0.0 )
284+
285+ scheme = MultiProcessWeightSyncScheme (strategy = "state_dict" )
286+ scheme .init_on_worker (model_id = "policy" , pipe = pipe , model = policy )
287+ receiver = scheme .get_receiver ()
288+
289+ # Non-blocking receive should return False when no data
290+ result = receiver .receive (timeout = 0.001 )
291+ assert result is False
292+
293+ # Now actually receive the weights
294+ result = receiver .receive (timeout = 5.0 )
295+ assert result is True
296+
297+ # Check weights were applied
298+ return policy .weight .sum ().item (), policy .bias .sum ().item ()
299+
300+ parent_pipe , child_pipe = mp .Pipe ()
301+
302+ scheme = MultiProcessWeightSyncScheme (strategy = "state_dict" )
303+ scheme .init_on_sender (model_id = "policy" , pipes = [parent_pipe ])
304+ sender = scheme .get_sender ()
305+
306+ proc = mp .Process (target = worker_with_receive , args = (child_pipe ,))
307+ proc .start ()
308+
309+ # Give worker time to call receive with no data
310+ import time
311+
312+ time .sleep (0.1 )
313+
314+ weights = {"weight" : torch .ones (2 , 4 ), "bias" : torch .ones (2 )}
315+ sender .send (weights )
316+
317+ proc .join (timeout = 10.0 )
318+ assert not proc .is_alive ()
319+
273320
274321class TestCollectorIntegration :
275322 @pytest .fixture
0 commit comments