
Conversation

@njhill
Member

@njhill njhill commented Sep 4, 2025

Follow-on from #23569.

This provides most of the speedup of that PR, like +22% rather than +25%. We could experiment with a slightly more complicated version where the worker runs in a separate thread, but this seems like a good first implementation.

@weijinqian0

In RL training scenarios, inference groups are typically managed externally, so the 'external launcher' executor is needed.

@Ronald1995
Contributor

This PR is based on top of #23569.

This provides most of the speedup of that PR, like +22% rather than +25%. We could experiment with a slightly more complicated version where the worker runs in a separate thread, but it seems unlikely it would exceed the MP performance anyhow, given that the uniproc executor now appears to be no faster than the multiproc one in general.

We implement async_scheduling in the uniproc executor not because it is faster than multiproc, but because we need to use uniproc (the external launcher method) in RL training.
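
For context, a minimal sketch of the combination being discussed, assuming the engine-argument names used by recent vLLM releases (`distributed_executor_backend="external_launcher"`, `async_scheduling`); check your installed version before relying on them:

```python
# Hypothetical usage sketch (not code from this PR): an externally launched
# rank, e.g. started via torchrun by an RL framework, creating an engine with
# async scheduling enabled. Argument names are assumptions based on recent
# vLLM releases.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-32B",                            # example model
    distributed_executor_backend="external_launcher",  # ranks are managed externally
    async_scheduling=True,                             # overlap scheduling with the forward pass
    tensor_parallel_size=4,
)
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=8))
```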

@njhill
Member Author

njhill commented Sep 5, 2025

@Ronald1995 @weijinqian0 I pushed another commit to also support external launcher executor, maybe you could try it out?

@Ronald1995
Contributor

Ronald1995 commented Sep 6, 2025

@Ronald1995 @weijinqian0 I pushed another commit to also support external launcher executor, maybe you could try it out?

@njhill thanks for the implementation of the external launcher method. I have tested it in my local environment based on your branch and made a small bugfix; both performance and precision are validated. Would you please cherry-pick my bugfix commit?

@njhill njhill force-pushed the uniproc-async-sched branch from 9b5e75a to cca2fab Compare September 6, 2025 02:59
@njhill
Member Author

njhill commented Sep 6, 2025

@Ronald1995 @weijinqian0 I pushed another commit to also support external launcher executor, maybe you could try it out?

@njhill thanks for the implementation of the external launcher method. I have tested it in my local environment based on your branch and made a small bugfix; both performance and precision are validated. Would you please cherry-pick my bugfix commit?

Thanks @Ronald1995. From your commit:

        if isinstance(outputs, Exception):
            logger.error("EngineCore step failed with error: %s", outputs)
            raise outputs

I don't think step_fn will ever return an exception, do you agree or could you show me why you think this is needed?

@njhill njhill marked this pull request as ready for review September 6, 2025 03:31
@Ronald1995
Contributor

@Ronald1995 @weijinqian0 I pushed another commit to also support external launcher executor, maybe you could try it out?

@njhill thanks for the implementation of the external launcher method. I have tested it in my local environment based on your branch and made a small bugfix; both performance and precision are validated. Would you please cherry-pick my bugfix commit?

Thanks @Ronald1995. From your commit:

        if isinstance(outputs, Exception):
            logger.error("EngineCore step failed with error: %s", outputs)
            raise outputs

I don't think step_fn will ever return an exception, do you agree or could you show me why you think this is needed?

I just saw that the get_output method in SyncMPClient handles the exception, but I think you are right: execute_model_with_error_logging will catch the exception and just re-raise it, so the output will not be an exception.
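
For readers following along, a small illustration of the pattern being described, with hypothetical names rather than vLLM's actual code: a wrapper that logs and re-raises never returns the exception object itself.

```python
import logging

logger = logging.getLogger(__name__)

def execute_with_error_logging(step_fn):
    """Hypothetical stand-in for execute_model_with_error_logging."""
    try:
        return step_fn()  # normal path: plain model outputs
    except Exception:
        logger.exception("EngineCore step failed")
        raise             # re-raised, never returned

# Callers therefore never see `isinstance(outputs, Exception)` evaluate to True.
```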

@njhill njhill added the ready label (ONLY add when PR is ready to merge/full CI is needed) Sep 8, 2025
@njhill
Member Author

njhill commented Sep 8, 2025

@Ronald1995 any chance you could try out the latest version of this PR with your use case?

I think this is ready apart from extra CI test coverage.

@njhill njhill added the needs-tests label (Tests needed for this PR) Sep 8, 2025
@Ronald1995
Contributor

Ronald1995 commented Sep 9, 2025

@Ronald1995 any chance you could try out the latest version of this PR with your use case?

I think this is ready apart from extra CI test coverage.

OK, I will test the latest version in my local environment. Because the GPU resources are currently occupied by my colleague, I can't test it right now; I will complete the test by tomorrow morning and report the results to you.

@Ronald1995
Contributor

@Ronald1995 any chance you could try out the latest version of this PR with your use case?
I think this is ready apart from extra CI test coverage.

OK, I will test the latest version in my local environment. Because the GPU resources are currently occupied by my colleague, I can't test it right now; I will complete the test by tomorrow morning and report the results to you.

@njhill I have tested the latest version with the external_launcher scenario; both performance and precision meet expectations.

MengqingCao pushed a commit to vllm-project/vllm-ascend that referenced this pull request Sep 11, 2025
This PR is based on top of
[#23569](vllm-project/vllm#23569) and
[#24219](vllm-project/vllm#24219).

### What this PR does / why we need it?
This PR allows the model runner to function asynchronously when using
async scheduling. This allows full overlap of the CPU operations
(including prepare_inputs) and the model forward pass. This diff is
functional but does not support speculative decoding, PP, or guided
decoding.

Expected speedup is 5-10% over the current async scheduling.
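
As an aside, a rough sketch of the overlap idea described above (hypothetical helper names, not the actual model-runner code): CPU input preparation for step N+1 runs while the forward pass for step N is still in flight.

```python
from concurrent.futures import ThreadPoolExecutor

def run_overlapped(schedule, prepare_inputs, forward, num_steps):
    """Prepare inputs for the next step on the CPU while the current forward runs."""
    pool = ThreadPoolExecutor(max_workers=1)
    next_inputs = prepare_inputs(schedule(0))                 # CPU work for step 0
    for step in range(num_steps):
        inputs = next_inputs
        running = pool.submit(forward, inputs)                # forward pass for step N
        if step + 1 < num_steps:
            next_inputs = prepare_inputs(schedule(step + 1))  # overlapped CPU prep for step N+1
        running.result()                                      # wait before using the outputs
    pool.shutdown()
```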

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
server
```
python -m vllm.entrypoints.openai.api_server --model=Qwen3-32B\
	--trust-remote-code --enforce-eager \
	--distributed-executor-backend=mp \
	-tp=4 \
	--port 8006 \
	--max-model-len 32000 \
	--block-size 128 \
	--gpu-memory-utilization 0.99
```
client
```
python $TEST_PY --backend vllm --trust-remote-code --model Qwen3-32B \
  --dataset-name random --random-input-len 2048 --random-output-len 2048 \
  --ignore-eos\
  --num-prompts 48 --max-concurrency 48  --request-rate inf --temperature 0 \
  --metric-percentiles 90  --base-url http://localhost:8006 --save-result \
  --result-dir $PROFILER_DIR
```

Benchmark results for Qwen3-32B (TPOT, lower is better):
| | forward async | scheduler async | sync |
|-|-|-|-|
| avg TPOT (ms) | 41.73 | 41.86 | 44.20 |
| improvement vs. scheduler async | 0.3% | baseline | - |
| improvement vs. sync | 5.58% | - | baseline |

Benchmark results for Qwen2___5-VL-7B-Instruct (TPOT, lower is better):
| | forward async | sync |
|-|-|-|
| avg TPOT (ms) | 23.22 | 29.16 |
| improvement vs. sync | 20.3% | baseline |
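
For reference, the improvement rows follow directly from the average TPOT values (lower is better); small rounding differences aside, they can be reproduced as:

```python
# Recomputing the reported improvements from the averages above (TPOT assumed in ms).
sync, sched_async, fwd_async = 44.20, 41.86, 41.73
print(f"forward async vs scheduler async: {(sched_async - fwd_async) / sched_async:.1%}")  # ~0.3%
print(f"forward async vs sync:            {(sync - fwd_async) / sync:.2%}")                # ~5.6%

vl_sync, vl_fwd_async = 29.16, 23.22
print(f"Qwen2.5-VL forward async vs sync: {(vl_sync - vl_fwd_async) / vl_sync:.1%}")       # ~20.4%
```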


- vLLM version: main
- vLLM main:
vllm-project/vllm@e93f4cc

Signed-off-by: jiangpeng36 <[email protected]>
Signed-off-by: Ronald1995 <[email protected]>
Co-authored-by: jiangpeng36 <[email protected]>
Co-authored-by: Ronald1995 <[email protected]>
@mergify mergify bot removed the needs-rebase label Sep 12, 2025
Collaborator

@WoosukKwon WoosukKwon left a comment

LGTM!

@WoosukKwon WoosukKwon merged commit 4fdd6f5 into vllm-project:main Sep 12, 2025
44 checks passed
@njhill njhill deleted the uniproc-async-sched branch September 12, 2025 23:35
skyloevil pushed a commit to skyloevil/vllm that referenced this pull request Sep 13, 2025
…4219)

Signed-off-by: Nick Hill <[email protected]>
Signed-off-by: Ronald1995 <[email protected]>
Co-authored-by: Ronald1995 <[email protected]>
Co-authored-by: Robert Shaw <[email protected]>
dsxsteven pushed a commit to dsxsteven/vllm_splitPR that referenced this pull request Sep 15, 2025
bbartels pushed a commit to bbartels/vllm that referenced this pull request Sep 15, 2025
offline893 pushed a commit to offline893/vllm-ascend that referenced this pull request Sep 16, 2025
wangxiaoteng888 pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Sep 25, 2025
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
chopper0126 pushed a commit to chopper0126/vllm-ascend that referenced this pull request Sep 26, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
choprahetarth pushed a commit to Tandemn-Labs/vllm that referenced this pull request Oct 11, 2025
Angazenn pushed a commit to Angazenn/vllm-ascend that referenced this pull request Oct 21, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
wuxibin89 pushed a commit to volcengine/verl that referenced this pull request Oct 28, 2025
…pc() (#3934)

### What does this PR do?

This PR fixes a `TypeError` that occurs when newer versions of vLLM
(v0.11+) attempt to call
`ExternalZeroMQDistributedExecutor.collective_rpc`.

The issue stems from a recent vLLM update
(vllm-project/vllm#24219) that added the keyword
argument `non_block` to the `Executor.collective_rpc` interface. Since
the `verl` implementation of `collective_rpc` did not define this
parameter, calling it with `non_block=True` resulted in the error:
`TypeError: ExternalZeroMQDistributedExecutor.collective_rpc() got an
unexpected keyword argument 'non_block'`.

By using `**extra_kwargs` in the function signature, we ensure
compatibility with both legacy and modern vLLM interfaces without
affecting the existing ZeroMQ non-blocking logic.
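
A minimal sketch of that signature change, with the method body reduced to a placeholder (the real executor performs a ZeroMQ request/response round-trip):

```python
from typing import Any, Callable, Optional, Union

class ExternalZeroMQDistributedExecutor:
    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict] = None,
                       **extra_kwargs: Any) -> list[Any]:
        # extra_kwargs absorbs keywords added by newer vLLM versions (such as
        # non_block=True) that the ZeroMQ dispatch path does not use.
        del extra_kwargs
        return []  # placeholder; the real implementation sends the RPC to the workers

# Both call styles are accepted:
executor = ExternalZeroMQDistributedExecutor()
executor.collective_rpc("execute_model")                   # older vLLM
executor.collective_rpc("execute_model", non_block=True)   # vLLM v0.11+
```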

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: weikaiwen <[email protected]>
wangboxiong320 pushed a commit to wangboxiong320/verl that referenced this pull request Nov 1, 2025
NenoL2001 pushed a commit to NenoL2001/verl that referenced this pull request Nov 3, 2025
chenjiaoAngel added a commit to chenjiaoAngel/verl that referenced this pull request Nov 14, 2025
chenhaiq pushed a commit to The-Hierophant/verl-1 that referenced this pull request Nov 18, 2025