|
@@ -47,6 +47,7 @@
 from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config)
 from vllm.distributed import tensor_model_parallel_all_gather
+from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
@@ -91,12 +92,15 @@
                                          UniformTypeKVCacheSpecs)
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                              DraftTokenIds, LogprobsTensors, ModelRunnerOutput,
-                             PoolerOutput)
+                             PoolerOutput,
+                             make_empty_encoder_model_runner_output)
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.utils import CpuGpuBuffer
+from vllm.v1.worker.ec_connector_model_runner_mixin import \
+    ECConnectorModelRunnerMixin
 from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache,
@@ -269,7 +273,7 @@ class ExecuteModelState(NamedTuple): |
     positions: torch.Tensor
 
 
-class NPUModelRunner(LoRAModelRunnerMixin):
+class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
 
     def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.vllm_config = vllm_config
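
The runner now also inherits from `ECConnectorModelRunnerMixin`, whose helpers are called unconditionally in the hunks below. The following is a hypothetical sketch of that mixin pattern, not the vLLM-Ascend implementation: the two helper names come from this diff, while the `_ec_connector` attribute and its `save`/`load_into` methods are assumptions made for illustration. The key property is that both helpers degrade to no-ops when no encoder-cache connector is configured.

```python
# Hypothetical sketch of the mixin pattern, NOT the actual
# ECConnectorModelRunnerMixin: helper names match this diff, everything
# else (_ec_connector, save, load_into) is assumed for illustration.
from contextlib import contextmanager
from typing import Any, Optional


class SketchECConnectorMixin:
    _ec_connector: Optional[Any] = None  # assumed attribute, set when configured

    def maybe_save_ec_to_connector(self, encoder_cache: dict,
                                   mm_hash: str) -> None:
        # Producer side: ship the just-computed encoder output, keyed by
        # the multimodal content hash; do nothing when no connector exists.
        if self._ec_connector is not None:
            self._ec_connector.save(mm_hash, encoder_cache[mm_hash])

    @contextmanager
    def maybe_get_ec_connector_output(self, scheduler_output,
                                      encoder_cache: dict):
        # Consumer side: let the connector inject remotely produced encoder
        # outputs into `encoder_cache` before the local encoder runs.
        if self._ec_connector is None:
            yield None              # no connector configured: plain no-op scope
            return
        self._ec_connector.load_into(encoder_cache, scheduler_output)
        yield self._ec_connector
```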
@@ -792,6 +796,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: |
 
             req_ids_to_add.append(req_id)
 
+        # If this rank is an EC transfer producer,
+        # skip updating the states of KV cache blocks.
+        if has_ec_transfer() and get_ec_transfer().is_producer:
+            return
+
         # Update the states of the running/resumed requests.
         is_last_rank = get_pp_group().is_last_rank
         req_data = scheduler_output.scheduled_cached_reqs
@@ -1094,6 +1103,7 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): |
                 output,
                 is_embed=pos_info.is_embed,
             )
+            self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)
 
     def _batch_mm_kwargs_from_scheduler(
         self,
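
The new save call is keyed by `mm_hash`, the content hash the scheduler uses to identify a multimodal item, so only the entry just written to `self.encoder_cache` is handed to the connector and identical inputs are not re-encoded or re-sent. A standalone sketch of that content-hash keying (the hashing scheme below is illustrative, not vLLM's actual `mm_hash` computation):

```python
# Standalone sketch of content-hash keying for the encoder cache.
# The sha256-over-raw-bytes scheme is an assumption for illustration only.
import hashlib

import torch


def mm_hash_of(image_bytes: bytes) -> str:
    return hashlib.sha256(image_bytes).hexdigest()


encoder_cache: dict[str, torch.Tensor] = {}


def encode_if_missing(image_bytes: bytes) -> torch.Tensor:
    key = mm_hash_of(image_bytes)
    if key not in encoder_cache:
        encoder_cache[key] = torch.zeros(4, 8)  # stand-in for encoder output
        # A producer would hand exactly this entry to the connector here,
        # mirroring maybe_save_ec_to_connector(self.encoder_cache, mm_hash).
    return encoder_cache[key]


print(encode_if_missing(b"fake image bytes").shape)
```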
@@ -1620,15 +1630,19 @@ def _prepare_inputs( |
         # _prepare_inputs may reorder the batch, so we must gather
         # multi-modal outputs after that to ensure the correct order
         if self.is_multimodal_model:
-            # Run the multimodal encoder if any.
-            self._execute_mm_encoder(scheduler_output)
-
-            # NOTE(woosuk): To unify token ids and soft tokens (vision
-            # embeddings), we always use embeddings (rather than token ids)
-            # as input to the multimodal model, even when the input is text.
-            input_ids = self.input_ids[:total_num_scheduled_tokens]
-            mm_embeds, is_mm_embed = self._gather_mm_embeddings(
-                scheduler_output)
+            with self.maybe_get_ec_connector_output(
+                    scheduler_output,
+                    encoder_cache=self.encoder_cache,
+            ):
+                # Run the multimodal encoder if any.
+                self._execute_mm_encoder(scheduler_output)
+
+                # NOTE(woosuk): To unify token ids and soft tokens (vision
+                # embeddings), we always use embeddings (rather than token ids)
+                # as input to the multimodal model, even when the input is text.
+                input_ids = self.input_ids[:total_num_scheduled_tokens]
+                mm_embeds, is_mm_embed = self._gather_mm_embeddings(
+                    scheduler_output)
 
             inputs_embeds = self.model.embed_input_ids(
                 input_ids,
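
Both the encoder pass and `_gather_mm_embeddings` now run inside a single `maybe_get_ec_connector_output` scope, so encoder outputs delivered by a remote producer can land in `self.encoder_cache` before embeddings are gathered, and connector bookkeeping still happens when the scope exits. A standalone sketch of that inject-then-finalize shape (the `Connector` class and its methods are assumptions for illustration):

```python
# Standalone sketch of the "inject before, finalize after" shape that a
# connector scope provides; the Connector class here is hypothetical.
from contextlib import contextmanager


class Connector:
    def inject(self, cache: dict) -> None:
        cache["remote-img"] = "embedding produced on another rank"

    def finalize(self) -> None:
        print("connector bookkeeping done")


@contextmanager
def ec_connector_scope(connector: Connector, cache: dict):
    connector.inject(cache)      # remote outputs land in the cache first
    try:
        yield
    finally:
        connector.finalize()     # runs even if the encoder step raises


cache: dict = {}
with ec_connector_scope(Connector(), cache):
    # The local encoder pass and embedding gather would run here and can
    # rely on "remote-img" already being present in `cache`.
    assert "remote-img" in cache
```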
@@ -2272,6 +2286,15 @@ def execute_model( |
 
         with ProfileExecuteDuration().capture_async("prepare input"):
             self._update_states(scheduler_output)
+            if has_ec_transfer() and get_ec_transfer().is_producer:
+                with self.maybe_get_ec_connector_output(
+                        scheduler_output,
+                        encoder_cache=self.encoder_cache,
+                ):
+                    self._execute_mm_encoder(scheduler_output)
+                return make_empty_encoder_model_runner_output(
+                    scheduler_output)
+
             if not scheduler_output.total_num_scheduled_tokens:
                 if not has_kv_transfer_group():
                     logger.debug(
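
On an EC producer rank, `execute_model` now ends after the encoder pass: there is no decode step, so the method returns the result of `make_empty_encoder_model_runner_output(scheduler_output)`. The sketch below is a hypothetical reduction of what such an "empty" output needs to convey, with made-up field names rather than the real `ModelRunnerOutput` schema: the scheduled request ids are echoed back with no sampled tokens, so the scheduler can advance without expecting token output from this rank.

```python
# Hypothetical reduction of an "encoder-only" runner output; field names
# are illustrative, not the real ModelRunnerOutput /
# make_empty_encoder_model_runner_output definitions.
from dataclasses import dataclass, field


@dataclass
class SketchRunnerOutput:
    req_ids: list[str]
    sampled_token_ids: list[list[int]] = field(default_factory=list)


def sketch_empty_encoder_output(scheduled_req_ids: list[str]) -> SketchRunnerOutput:
    # Echo the scheduled requests back with no generated tokens: the
    # producer only filled the encoder cache during this step.
    return SketchRunnerOutput(req_ids=list(scheduled_req_ids),
                              sampled_token_ids=[[] for _ in scheduled_req_ids])


print(sketch_empty_encoder_output(["req-0", "req-1"]))
```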
@@ -3769,6 +3792,10 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: |
             KVCacheSpec: A dictionary mapping layer names to their KV cache
             format. Layers that do not need KV cache are not included.
         """
+
+        if has_ec_transfer() and get_ec_transfer().is_producer:
+            return {}
+
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
         kv_cache_spec: dict[str, KVCacheSpec] = {}
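
Because a pure encoder producer never attends over a KV cache, `get_kv_cache_spec` reports no layers, and downstream KV block allocation sizes to zero on that rank. A small illustrative sketch of the effect (the accounting below is a stand-in, not vLLM's allocator logic):

```python
# Sketch: an empty KV cache spec implies zero bytes of KV block memory.
# The per-layer page sizes and the summation are illustrative only.
def kv_bytes_needed(kv_cache_spec: dict[str, int], num_blocks: int) -> int:
    # Values stand in for per-layer page sizes in bytes.
    return num_blocks * sum(kv_cache_spec.values())


print(kv_bytes_needed({}, num_blocks=1024))                             # producer: 0
print(kv_bytes_needed({"layer_0": 2 * 1024 * 1024}, num_blocks=1024))   # consumer
```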
|