
fp8 broadcast #1981

Open
S1ro1 wants to merge 6 commits into `main` from `weight-transfer`

Conversation

@S1ro1 (Collaborator) commented Mar 7, 2026

Note

High Risk
Touches the training↔inference weight-update pipeline (NCCL broadcast, in-place parameter updates, and optional FP8 quantization), where shape/format mismatches can break live inference or silently degrade quality. Also changes multi-node SLURM orchestration and dependency pinning, which can affect cluster runs and reproducibility.

Overview
Adds an opt-in vLLM kernel-format weight broadcast path for NCCL, including optional block-wise FP8 quantization, to enable direct in-place updates on inference workers without converting through HF checkpoint format.

Plumbs new config flags (e.g. use_vllm_format_transfer, quantize_fp8) through shared/trainer/orchestrator configs and the /init_broadcaster RPC; updates the vLLM NCCL worker to copy_() received params (with EP expert slicing and MLA absorbed-weight recomputation) and introduces model-side layer conversion for glm_moe_dsa.

Extends multi-node RL deployment to support multiple inference replicas (num_infer_replicas/total_infer_nodes) and updates the SLURM template for per-replica head selection and headless nodes; also relaxes inference api_server_count to allow 0 for headless mode and adjusts model-name propagation/validation so orchestrator matches inference.

Written by Cursor Bugbot for commit d685dfc. This will update automatically on new commits.

```toml
[[tool.uv.index]]
name = "vllm-nightly"
url = "https://wheels.vllm.ai/nightly"

# changed PyTorch index line (separate index entry in the diff):
url = "https://download.pytorch.org/whl/test/cu128"
```

PyTorch sourced from test/pre-release channel instead of stable

High Severity

The PyTorch index URL was changed from https://download.pytorch.org/whl/cu128 (stable releases) to https://download.pytorch.org/whl/test/cu128 (pre-release/test builds). This causes the project to install torch 2.9.1+cu128 from the test channel instead of a stable release. Test channel builds may contain regressions or breaking changes that haven't been validated for production use.
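
If the test channel was unintentional, the fix is to point the index back at the stable wheel repository. A sketch of the stable pin (the index `name` here is illustrative; only the URL comes from the bot's report):

```toml
[[tool.uv.index]]
name = "pytorch-cu128"
# stable release channel, instead of the pre-release /test/ channel
url = "https://download.pytorch.org/whl/cu128"
```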

Additional Locations (1)


Cursor bot (@cursor) left a comment

Cursor Bugbot has reviewed your changes and found 2 potential issues.

Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.

```python
quantize_fp8: Annotated[
    bool,
    Field(
        description="Quantize weights to FP8 (e4m3) with block-wise scaling during kernel format transfer. "
        "Only used when use_kernel_format_transfer is True."
    ),
] = False
```

Config changes missing CHANGELOG.md update

Low Severity

Multiple new config fields were added across src/prime_rl/configs/inference.py (data_parallel_address, data_parallel_start_rank, headless), src/prime_rl/configs/rl.py (use_kernel_format_transfer, quantize_fp8, allow_different_inference_model), src/prime_rl/configs/trainer.py (use_kernel_format_transfer, quantize_fp8), and src/prime_rl/configs/orchestrator.py (use_kernel_format_transfer), but CHANGELOG.md was not updated with entries for any of these new fields.

Additional Locations (2)


Triggered by project rule: BugBot Instructions

```python
if shape_mismatches:
    logger.error(f"Kernel weight transfer: {len(shape_mismatches)} SHAPE MISMATCHES: {shape_mismatches}")
if skipped:
    logger.warning(f"Kernel weight transfer: {len(skipped)} skipped (not in model): {skipped}")
logger.info(f"Kernel weight transfer: copied {loaded} weights in-place")
```

Kernel format loader only checks parameters, missing buffers

Medium Severity

_load_kernel_format builds its lookup dict from model.named_parameters(), which excludes buffers. When quantize_fp8 is enabled on the sender, weight_scale_inv tensors are broadcast alongside FP8 weights. If the receiving vLLM model stores these scale tensors as buffers rather than parameters, they won't be matched and will be silently skipped, leaving stale scale factors that cause incorrect dequantization during inference.
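
A minimal illustration of the failure mode (the module and its `weight` are hypothetical; only the `weight_scale_inv` name comes from the review): a lookup built from `named_parameters()` alone misses registered buffers, while merging in `named_buffers()` catches them.

```python
import torch
import torch.nn as nn

class TinyFp8Linear(nn.Module):
    """Hypothetical module: an FP8-style weight plus a scale stored as a buffer."""
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 4))
        # some models register scale tensors as buffers, not parameters
        self.register_buffer("weight_scale_inv", torch.ones(1))

model = TinyFp8Linear()

params_only = dict(model.named_parameters())                   # misses buffers
full_lookup = {**params_only, **dict(model.named_buffers())}   # catches both

print("weight_scale_inv" in params_only)   # False: would be silently skipped
print("weight_scale_inv" in full_lookup)   # True
```

Building the lookup from both iterators lets the broadcast loop `copy_()` into scale buffers the same way it does into parameters.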


Comment on lines +229 to +249
```python
data_parallel_address: Annotated[
    str | None,
    Field(
        description="Address for cross-node data parallel communication. Passed to vLLM as `--data-parallel-address`.",
    ),
] = None

data_parallel_start_rank: Annotated[
    int | None,
    Field(
        ge=0,
        description="Starting DP rank for this node in multi-node EP. Passed to vLLM as `--data-parallel-start-rank`.",
    ),
] = None

headless: Annotated[
    bool,
    Field(
        description="Run in headless mode (no API server). Passed to vLLM as `--headless`.",
    ),
] = False
```
Member

Can pass all of this via `vllm_extras` instead, imo.
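
For reference, the reviewer's alternative could look roughly like this (the `vllm_extras` field and the `build_vllm_args` helper are hypothetical sketches, not code from this PR): a generic passthrough list of raw vLLM CLI flags instead of one typed config field per flag.

```python
def build_vllm_args(model: str, vllm_extras: list[str]) -> list[str]:
    """Hypothetical: forward raw vLLM CLI flags from a passthrough config list."""
    return ["vllm", "serve", model, *vllm_extras]

# a headless multi-node EP node, expressed as raw flags instead of typed fields
args = build_vllm_args(
    "my-org/my-model",
    ["--data-parallel-address", "10.0.0.1",
     "--data-parallel-start-rank", "4",
     "--headless"],
)
print(args)
```

The trade-off is losing per-field validation (e.g. the `ge=0` check on the start rank) in exchange for not mirroring every vLLM flag in the config schema.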

Comment on lines +120 to +126
```python
use_kernel_format_transfer: Annotated[
    bool,
    Field(
        description="Transfer weights in vLLM kernel format instead of HF checkpoint format. "
        "Avoids the HF conversion intermediate step and allows direct in-place weight updates."
    ),
] = False
```
Member

Suggested change

```diff
-use_kernel_format_transfer: Annotated[
+use_vllm_format_transfer: Annotated[
     bool,
     Field(
         description="Transfer weights in vLLM kernel format instead of HF checkpoint format. "
         "Avoids the HF conversion intermediate step and allows direct in-place weight updates."
     ),
 ] = False
```

Let's rename to this.

```python
port: Annotated[int, Field(description="The port to use for the NCCL broadcast.")] = 29501
timeout: Annotated[int, Field(description="The timeout in seconds to use for the NCCL broadcast.")] = 1200

use_kernel_format_transfer: Annotated[
```
Member

Can we move this to the vLLM broadcast config instead of having it on the orchestrator?

Comment on lines +299 to +307
```python
allow_different_inference_model: Annotated[
    bool,
    Field(
        description="Allow the inference server to use a different model name than the trainer. "
        "When enabled, the orchestrator uses the inference model name for querying. "
        "Useful for kernel format weight transfer where the trainer uses a bf16 model "
        "and inference uses a quantized (e.g. FP8) variant.",
    ),
] = False
```
Member

Does that mean we always need to load the FP8 model on the vLLM side?

I think we should just allow this by default and not have this param, imo. It should still translate from `model.name` to `trainer.model.name` and `infer.model.name`, but the user should be able to override it via `trainer.model.name`, basically.

Comment on lines +115 to +117
```shell
    --data_parallel_start_rank $INFER_DP_START_RANK \
    --data_parallel_address $INFER_HEAD_HOST \
    --data_parallel_rpc_port $INFERENCE_DATA_PARALLEL_RPC_PORT \
```
Member

Should this always be enabled?

@samsja changed the title from "checkpoint" to "fp8 broadcast" on Mar 9, 2026