
Commit adce4f1

Encoder separation for Encode-Prefill-Decode Disaggregation
Signed-off-by: amy-why-3459 <[email protected]>
1 parent 366d2d9 commit adce4f1

File tree

4 files changed: +73 -11 lines changed

vllm_ascend/patch/platform/__init__.py
vllm_ascend/patch/platform/patch_ec_connector.py
vllm_ascend/worker/model_runner_v1.py
vllm_ascend/worker/worker_v1.py

vllm_ascend/patch/platform/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@
 import vllm_ascend.patch.platform.patch_config # noqa
 import vllm_ascend.patch.platform.patch_distributed # noqa
 import vllm_ascend.patch.platform.patch_dynamo_vllm_backend # noqa
+import vllm_ascend.patch.platform.patch_ec_connector # noqa
 import vllm_ascend.patch.platform.patch_mamba_config # noqa
 import vllm_ascend.patch.platform.patch_sched_yield # noqa
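
Note: the bare "# noqa" import added above is all it takes for the EC-connector patch to apply, because the patch module (added in the next diff) rebinds a class attribute on the upstream vLLM module as an import side effect. A minimal sketch of that mechanism, assuming callers resolve the connector class through the module attribute rather than through an earlier "from ... import" binding:

# Illustrative sketch (not part of the commit): importing the platform patch
# package triggers patch_ec_connector, which swaps the connector class on the
# upstream module object.
import vllm_ascend.patch.platform  # applies all platform patches on import

from vllm.distributed.ec_transfer.ec_connector import shared_storage_connector

# Lookups that go through the module attribute now resolve to the Ascend
# subclass; names bound via "from ... import" before the patch are unaffected.
print(shared_storage_connector.ECSharedStorageConnector.__name__)
# -> AscendECSharedStorageConnector (assuming the patch ran first)

Registering the patch in the platform __init__ ensures it runs before vLLM instantiates the connector; any code that bound the original class by name earlier would keep the unpatched version.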

vllm_ascend/patch/platform/patch_ec_connector.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+import vllm.distributed.ec_transfer.ec_connector.shared_storage_connector
+from safetensors.torch import load_file
+from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
+from vllm.distributed.ec_transfer.ec_connector.shared_storage_connector import (
+    ECSharedStorageConnector, ECSharedStorageConnectorMetadata)
+from vllm.logger import logger
+
+
+class AscendECSharedStorageConnector(ECSharedStorageConnector):
+
+    def start_load_caches(self, encoder_cache, **kwargs) -> None:
+        metadata: ECConnectorMetadata = self._get_connector_metadata()
+        assert isinstance(metadata, ECSharedStorageConnectorMetadata)
+        assert encoder_cache is not None
+        if metadata is None:
+            logger.warning((
+                "In connector.start_load_caches, ",
+                "but the connector metadata is None",
+            ))
+            return
+        # Load the EC for each mm data
+        for mm_data in metadata.mm_datas:
+            if mm_data.mm_hash in encoder_cache:
+                continue
+            filename = self._generate_filename_debug(mm_data.mm_hash)
+            ec_cache = load_file(filename)["ec_cache"].npu()
+            encoder_cache[mm_data.mm_hash] = ec_cache
+            logger.debug("Success load encoder cache for hash %s",
+                         mm_data.mm_hash)
+
+
+vllm.distributed.ec_transfer.ec_connector.shared_storage_connector.ECSharedStorageConnector = AscendECSharedStorageConnector
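
The load path above reads each multimodal item's encoder output back from a safetensors file keyed "ec_cache", at a path produced by the inherited _generate_filename_debug helper, and moves it onto the NPU with .npu(). A rough sketch of the matching producer-side write, where save_encoder_cache and the exact file layout are assumptions rather than code from this commit:

# Hypothetical save-side counterpart (assumption, not committed code): write
# one safetensors file per mm_hash, keyed "ec_cache", so that
# load_file(filename)["ec_cache"] in start_load_caches reads it back.
import torch
from safetensors.torch import save_file


def save_encoder_cache(connector, mm_hash: str, ec_cache: torch.Tensor) -> None:
    # Reuse the same filename scheme the load path queries.
    filename = connector._generate_filename_debug(mm_hash)
    # safetensors needs contiguous CPU tensors, so move off the device first.
    save_file({"ec_cache": ec_cache.contiguous().cpu()}, filename)

The final module-level assignment is what makes the subclass take effect: after it runs, the shared-storage EC connector resolves to AscendECSharedStorageConnector wherever it is looked up through the upstream module.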

vllm_ascend/worker/model_runner_v1.py

Lines changed: 38 additions & 11 deletions
@@ -47,6 +47,7 @@
 from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config)
 from vllm.distributed import tensor_model_parallel_all_gather
+from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
@@ -91,12 +92,15 @@
                                         UniformTypeKVCacheSpecs)
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                              DraftTokenIds, LogprobsTensors, ModelRunnerOutput,
-                             PoolerOutput)
+                             PoolerOutput,
+                             make_empty_encoder_model_runner_output)
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.utils import CpuGpuBuffer
+from vllm.v1.worker.ec_connector_model_runner_mixin import \
+    ECConnectorModelRunnerMixin
 from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache,
@@ -269,7 +273,7 @@ class ExecuteModelState(NamedTuple):
     positions: torch.Tensor


-class NPUModelRunner(LoRAModelRunnerMixin):
+class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):

     def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.vllm_config = vllm_config
@@ -792,6 +796,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:

             req_ids_to_add.append(req_id)

+        # If this rank is an EC transfer producer,
+        # skip updating the states of KV cache blocks.
+        if has_ec_transfer() and get_ec_transfer().is_producer:
+            return
+
         # Update the states of the running/resumed requests.
         is_last_rank = get_pp_group().is_last_rank
         req_data = scheduler_output.scheduled_cached_reqs
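
The has_ec_transfer() and get_ec_transfer().is_producer check added here also guards execute_model and get_kv_cache_spec below. A tiny helper along these lines, shown only as a hypothetical refactor, would keep that role check in one place:

# Hypothetical helper (not in the commit): one place for the EC-producer
# check repeated in _update_states, execute_model and get_kv_cache_spec.
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer


def is_ec_producer() -> bool:
    # True only when EC transfer is configured and this rank's job is to
    # encode multimodal inputs for other instances rather than to decode.
    return has_ec_transfer() and get_ec_transfer().is_producer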
@@ -1094,6 +1103,7 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
                 output,
                 is_embed=pos_info.is_embed,
             )
+            self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)

     def _batch_mm_kwargs_from_scheduler(
         self,
@@ -1620,15 +1630,19 @@ def _prepare_inputs(
         # _prepare_inputs may reorder the batch, so we must gather
         # multi-modal outputs after that to ensure the correct order
         if self.is_multimodal_model:
-            # Run the multimodal encoder if any.
-            self._execute_mm_encoder(scheduler_output)
-
-            # NOTE(woosuk): To unify token ids and soft tokens (vision
-            # embeddings), we always use embeddings (rather than token ids)
-            # as input to the multimodal model, even when the input is text.
-            input_ids = self.input_ids[:total_num_scheduled_tokens]
-            mm_embeds, is_mm_embed = self._gather_mm_embeddings(
-                scheduler_output)
+            with self.maybe_get_ec_connector_output(
+                    scheduler_output,
+                    encoder_cache=self.encoder_cache,
+            ):
+                # Run the multimodal encoder if any.
+                self._execute_mm_encoder(scheduler_output)
+
+                # NOTE(woosuk): To unify token ids and soft tokens (vision
+                # embeddings), we always use embeddings (rather than token ids)
+                # as input to the multimodal model, even when the input is text.
+                input_ids = self.input_ids[:total_num_scheduled_tokens]
+                mm_embeds, is_mm_embed = self._gather_mm_embeddings(
+                    scheduler_output)

             inputs_embeds = self.model.embed_input_ids(
                 input_ids,
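
maybe_get_ec_connector_output comes from the upstream ECConnectorModelRunnerMixin; this hunk only wraps the existing encoder path in it. Purely as an illustration of the pattern (the mixin's real internals are not part of this diff), a context manager with this shape would let a consumer instance pull encoder caches from the connector before running the local encoder, so already-transferred items are skipped:

# Illustrative sketch only, not the vLLM implementation: fill encoder_cache
# from the EC connector on entry so _execute_mm_encoder only computes items
# that are still missing from the cache.
from contextlib import contextmanager


@contextmanager
def ec_connector_load(connector, scheduler_output, encoder_cache):
    if connector is not None:
        # Corresponds to start_load_caches in the patched connector above.
        connector.start_load_caches(encoder_cache)
    try:
        yield
    finally:
        # A real implementation would also hand connector output (e.g. which
        # mm hashes finished transferring) back to the scheduler here.
        pass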
@@ -2272,6 +2286,15 @@ def execute_model(

         with ProfileExecuteDuration().capture_async("prepare input"):
             self._update_states(scheduler_output)
+            if has_ec_transfer() and get_ec_transfer().is_producer:
+                with self.maybe_get_ec_connector_output(
+                        scheduler_output,
+                        encoder_cache=self.encoder_cache,
+                ):
+                    self._execute_mm_encoder(scheduler_output)
+                return make_empty_encoder_model_runner_output(
+                    scheduler_output)
+
             if not scheduler_output.total_num_scheduled_tokens:
                 if not has_kv_transfer_group():
                     logger.debug(
@@ -3769,6 +3792,10 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
             KVCacheSpec: A dictionary mapping layer names to their KV cache
             format. Layers that do not need KV cache are not included.
         """
+
+        if has_ec_transfer() and get_ec_transfer().is_producer:
+            return {}
+
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
         kv_cache_spec: dict[str, KVCacheSpec] = {}

vllm_ascend/worker/worker_v1.py

Lines changed: 2 additions & 0 deletions
@@ -30,6 +30,7 @@
 from vllm.config import VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
+from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
 from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
 from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import logger

@@ -426,6 +427,7 @@ def _init_worker_distributed_environment(self) -> None:
             self.parallel_config.decode_context_parallel_size)
         init_ascend_model_parallel(self.parallel_config)
         ensure_kv_transfer_initialized(self.vllm_config)
+        ensure_ec_transfer_initialized(self.vllm_config)

     def _init_profiler(self):
         # Torch profiler. Enabled and configured through env vars:
