 
 if prefill_context_parallel_enable():
     from vllm.distributed import get_pcp_group
+if shared_expert_dp_enabled():
+    from vllm.distributed import get_tensor_model_parallel_world_size
 
 from vllm.utils.platform_utils import is_pin_memory_available
 from vllm.utils.torch_utils import set_default_torch_dtype
@@ -298,6 +300,10 @@ def dummy_run(self,
             self.model(input_ids=input_ids,
                        positions=positions,
                        hidden_states=previous_hidden_states)
+            if self.enable_shared_expert_dp:
+                positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(positions, True)
+                previous_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                    previous_hidden_states, True)
         forward_context = get_forward_context()
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
                 not forward_context.capturing:
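
For readers following the shared-expert-DP path: the all-gather/unpad pair here mirrors the pad/reduce-scatter applied before the layer stack, so `positions` and `previous_hidden_states` come back to the full token count after the dummy forward. A minimal single-process sketch of just the shape bookkeeping (the `pad_and_shard`/`gather_and_unpad` helpers below are illustrative stand-ins, not the real `torch.ops.vllm` custom ops):

```python
import torch

def pad_and_shard(x: torch.Tensor, tp: int, rank: int) -> torch.Tensor:
    # Pad the token dim up to a multiple of tp, then keep this rank's shard
    # (stand-in for the reduce_scatter side of the collective).
    num_tokens = x.shape[0]
    padded = (num_tokens + tp - 1) // tp * tp
    x = torch.nn.functional.pad(x, (0, 0, 0, padded - num_tokens))
    return x.chunk(tp, dim=0)[rank]

def gather_and_unpad(shards: list[torch.Tensor], num_tokens: int) -> torch.Tensor:
    # Concatenate all ranks' shards and drop the padding rows
    # (stand-in for the all-gather-and-unpad side).
    return torch.cat(shards, dim=0)[:num_tokens]

tp, num_tokens, hidden = 4, 10, 8
h = torch.randn(num_tokens, hidden)
shards = [pad_and_shard(h, tp, r) for r in range(tp)]
assert shards[0].shape[0] == 3            # 12 padded tokens / 4 ranks = 3 per rank
restored = gather_and_unpad(shards, num_tokens)
assert torch.equal(restored, h)           # the round trip preserves every token
```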
@@ -690,6 +696,12 @@ def _propose(
                 (self.num_speculative_tokens + 1))
             batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
                                                uniform_decode=uniform_decode)
+            # Enabling sp/shared_expert_dp performs a reduce_scatter operation.
+            if self.enable_shared_expert_dp:
+                tp_world_size = get_tensor_model_parallel_world_size()
+                reduce_num_input_tokens = num_input_tokens // tp_world_size
+                batch_descriptor = BatchDescriptor(num_tokens=reduce_num_input_tokens,
+                                                   uniform_decode=uniform_decode)
         else:
             batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
                                                uniform_decode=False)
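
The descriptor change only records the per-rank token count after that reduce_scatter; since the batch descriptor is what the CUDA-graph runtime keys on, it has to match the per-rank shape rather than the global one. A worked example of the arithmetic, assuming the padded token count is a multiple of the TP size (numbers are illustrative):

```python
# Illustrative numbers only: a padded batch of 512 tokens with TP=4.
num_input_tokens = 512
tp_world_size = 4

# After the sp/shared_expert_dp reduce_scatter each rank holds
# 512 // 4 = 128 tokens, so the descriptor is keyed on 128, not 512.
reduce_num_input_tokens = num_input_tokens // tp_world_size
assert reduce_num_input_tokens == 128
```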
@@ -741,12 +753,15 @@ def _propose(
             positions = torch.ops.vllm.maybe_pad_and_reduce(
                 positions)
             positions = positions.squeeze(-1)
+            hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
+                hidden_states)
 
             hidden_states = self.model(input_ids=input_ids,
                                        positions=positions,
                                        hidden_states=hidden_states)
             hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 hidden_states.contiguous(), True)
+
             forward_context = get_forward_context()
             if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
                 if self.vllm_config.model_config.use_mla:
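
Note the symmetry in this hunk: `hidden_states` (like `positions`) is padded and reduced to the local shard before the draft forward, then gathered and unpadded right after, and the `squeeze(-1)` suggests `positions` crosses the collective with a trailing singleton dimension. A toy illustration of that shape handling, with plain slicing standing in for the actual reduce_scatter (assumption: the real op works on a `[tokens, 1]` layout for positions):

```python
import torch

tp, num_tokens = 4, 10
positions = torch.arange(num_tokens)                    # 1-D [num_tokens]

# Give positions a trailing dim, pad the token dim to a multiple of tp,
# and keep one rank's shard (stand-in for the pad-and-reduce on rank 0).
padded = (num_tokens + tp - 1) // tp * tp
shard = torch.nn.functional.pad(positions.unsqueeze(-1),
                                (0, 0, 0, padded - num_tokens))
shard = shard.chunk(tp, dim=0)[0]

# Back to 1-D for the attention metadata / RoPE lookups, as in the diff.
local_positions = shard.squeeze(-1)
assert local_positions.shape == (padded // tp,)         # 12 // 4 = 3 tokens
```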
@@ -821,20 +836,20 @@ def _propose(
                 batch_size,
                 attn_metadata_i.decode.actual_seq_lengths_q)
             attn_metadata_i.decode.cos = builder.cos_cache[
-                positions].unsqueeze(1).unsqueeze(2)
+                positions[:batch_size]].unsqueeze(1).unsqueeze(2)
             attn_metadata_i.decode.sin = builder.sin_cache[
-                positions].unsqueeze(1).unsqueeze(2)
+                positions[:batch_size]].unsqueeze(1).unsqueeze(2)
             # NOTE(woosuk): We should handle the case where the draft model
             # generates tokens beyond the max model length. Since it is complex
             # to remove such requests from the batch, we keep them in the batch
             # but adjust the position ids and slot mappings to avoid the
             # out-of-range access during the model execution. The draft tokens
             # generated with this adjustment should be ignored.
-            exceeds_max_model_len = positions >= self.runner.model_config.max_model_len
+            exceeds_max_model_len = positions[:batch_size] >= self.runner.model_config.max_model_len
             # Mask out the position ids that exceed the max model length.
             # Otherwise, we may get out-of-range error in RoPE.
             clamped_positions = torch.where(exceeds_max_model_len, 0,
-                                            positions)
+                                            positions[:batch_size])
             # Increment the sequence lengths.
             attn_metadata_i.seq_lens[:batch_size] += 1
             # For the requests that exceed the max model length, we set the
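
Because `positions` now carries the gathered (padded) length while only the first `batch_size` entries belong to live decode requests, the cos/sin cache lookups and the max-model-length check all slice with `[:batch_size]`. The clamping trick itself, in isolation (values are made up):

```python
import torch

max_model_len = 16
batch_size = 4

# A gathered positions buffer longer than the live batch; only the first
# batch_size entries are real.
positions = torch.tensor([3, 15, 16, 20, 0, 0, 0, 0])

exceeds_max_model_len = positions[:batch_size] >= max_model_len
# Clamp out-of-range positions to 0 so RoPE never indexes past its table;
# drafts produced for those requests are discarded downstream.
clamped_positions = torch.where(exceeds_max_model_len, 0, positions[:batch_size])
print(clamped_positions)   # tensor([ 3, 15,  0,  0])
```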