|
8 | 8 | from collections.abc import Iterator |
9 | 9 | from contextlib import contextmanager |
10 | 10 | from copy import deepcopy |
11 | | -from typing import TYPE_CHECKING, Any, Optional, Union, cast |
| 11 | +from typing import Any, cast, Optional, Tuple, TYPE_CHECKING, Union |
12 | 12 |
|
13 | 13 | import numpy as np |
14 | 14 | import torch |
@@ -1452,6 +1452,39 @@ def _pool( |
1452 | 1452 | kv_connector_output=kv_connector_output, |
1453 | 1453 | ) |
1454 | 1454 |
|
| 1455 | + def _forward( |
| 1456 | + self, |
| 1457 | + attn_metadata: dict[str, Any], |
| 1458 | + num_input_tokens: int, |
| 1459 | + num_tokens_across_dp: int, |
| 1460 | + cudagraph_runtime_mode: CUDAGraphMode, |
| 1461 | + batch_descriptor: BatchDescriptor, |
| 1462 | + scheduler_output: "SchedulerOutput", |
| 1463 | + input_ids: torch.Tensor, |
| 1464 | + positions: torch.Tensor, |
| 1465 | + intermediate_tensors: IntermediateTensors, |
| 1466 | + inputs_embeds: list[torch.Tensor], |
| 1467 | + model_kwargs: dict[str, Any], |
| 1468 | + ) -> Tuple[torch.Tensor, Optional[KVConnectorOutput]]: |
| 1469 | + with set_forward_context( |
| 1470 | + attn_metadata, |
| 1471 | + self.vllm_config, |
| 1472 | + num_tokens=num_input_tokens, |
| 1473 | + num_tokens_across_dp=num_tokens_across_dp, |
| 1474 | + cudagraph_runtime_mode=cudagraph_runtime_mode, |
| 1475 | + batch_descriptor=batch_descriptor, |
| 1476 | + ), self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output: |
| 1477 | + return ( |
| 1478 | + self.model( |
| 1479 | + input_ids=input_ids, |
| 1480 | + positions=positions, |
| 1481 | + intermediate_tensors=intermediate_tensors, |
| 1482 | + inputs_embeds=inputs_embeds, |
| 1483 | + **model_kwargs, |
| 1484 | + ), |
| 1485 | + kv_connector_output, |
| 1486 | + ) |
| 1487 | + |
1455 | 1488 | @torch.inference_mode() |
1456 | 1489 | def execute_model( |
1457 | 1490 | self, |
@@ -1557,23 +1590,19 @@ def execute_model( |
1557 | 1590 |
|
1558 | 1591 | # Run the model. |
1559 | 1592 | # Use persistent buffers for CUDA graphs. |
1560 | | - with set_forward_context( |
1561 | | - attn_metadata, |
1562 | | - self.vllm_config, |
1563 | | - num_tokens=num_input_tokens, |
1564 | | - num_tokens_across_dp=num_tokens_across_dp, |
1565 | | - cudagraph_runtime_mode=cudagraph_runtime_mode, |
1566 | | - batch_descriptor=batch_descriptor, |
1567 | | - ), self.maybe_get_kv_connector_output( |
1568 | | - scheduler_output) as kv_connector_output: |
1569 | | - |
1570 | | - model_output = self.model( |
1571 | | - input_ids=input_ids, |
1572 | | - positions=positions, |
1573 | | - intermediate_tensors=intermediate_tensors, |
1574 | | - inputs_embeds=inputs_embeds, |
1575 | | - **model_kwargs, |
1576 | | - ) |
| 1593 | + model_output, kv_connector_output = self._forward( |
| 1594 | + attn_metadata, |
| 1595 | + num_input_tokens=num_input_tokens, |
| 1596 | + num_tokens_across_dp=num_tokens_across_dp, |
| 1597 | + cudagraph_runtime_mode=cudagraph_runtime_mode, |
| 1598 | + batch_descriptor=batch_descriptor, |
| 1599 | + scheduler_output=scheduler_output, |
| 1600 | + input_ids=input_ids, |
| 1601 | + positions=positions, |
| 1602 | + intermediate_tensors=intermediate_tensors, |
| 1603 | + inputs_embeds=inputs_embeds, |
| 1604 | + model_kwargs=model_kwargs, |
| 1605 | + ) |
1577 | 1606 |
|
1578 | 1607 | if self.use_aux_hidden_state_outputs: |
1579 | 1608 | hidden_states, aux_hidden_states = model_output |
|
0 commit comments