diff --git a/tests/test_averaging.py b/tests/test_averaging.py index 7a6d37dfa..776b45747 100644 --- a/tests/test_averaging.py +++ b/tests/test_averaging.py @@ -521,7 +521,9 @@ def test_averaging_trigger(): @pytest.mark.forked -def test_averaging_cancel(): +@pytest.mark.parametrize("target_group_size", [None, 2]) +def test_averaging_cancel(target_group_size): + dht_instances = launch_dht_instances(4) averagers = tuple( DecentralizedAverager( averaged_tensors=[torch.randn(3)], @@ -529,23 +531,35 @@ def test_averaging_cancel(): min_matchmaking_time=0.5, request_timeout=0.3, client_mode=(i % 2 == 0), + target_group_size=target_group_size, prefix="mygroup", start=True, ) - for i, dht in enumerate(launch_dht_instances(4)) + for i, dht in enumerate(dht_instances) ) - step_controls = [averager.step(wait=False, scheduled_time=hivemind.get_dht_time() + 1) for averager in averagers] + step_controls = [averager.step(wait=False, require_trigger=True) for averager in averagers] + + peer_inds_to_cancel = (0, 1) + + for peer_index in peer_inds_to_cancel: + step_controls[peer_index].cancel() time.sleep(0.05) - step_controls[0].cancel() - step_controls[1].cancel() for i, control in enumerate(step_controls): - if i in (0, 1): + if i not in peer_inds_to_cancel: + control.allow_allreduce() + + for i, control in enumerate(step_controls): + if i in peer_inds_to_cancel: assert control.cancelled() else: - assert control.result() is not None and len(control.result()) == 2 + result = control.result() + assert result is not None + # Don't check group size when target_group_size=None, as it could change + if target_group_size is not None: + assert len(result) == target_group_size for averager in averagers: averager.shutdown()