
Commit ffbd40c

shared_exp_dp_v2
Signed-off-by: chenmenglong <[email protected]>
1 parent 02f4ddf commit ffbd40c

File tree

3 files changed: +53 -58 lines changed

vllm_ascend/models/layers/mla.py

Lines changed: 14 additions & 11 deletions
@@ -151,17 +151,20 @@ def forward(
             hidden_states: torch.Tensor,
             kv_cache: Optional[torch.Tensor] = None,
             attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
-        forward_context = get_forward_context()
-        sp_enabled = forward_context.sp_enabled
-        need_gather_q_kv = False
-        if sp_enabled and self.debug_layer_idx < self.layers:
-            need_gather_q_kv = True
-        if not sp_enabled or self.debug_layer_idx < self.layers:
-            output_shape = hidden_states.shape
-        else:
-            # used in deepseek mtp layer
-            output_shape = torch.chunk(hidden_states, self.tp_size,
-                                       dim=0)[0].shape
+        # forward_context = get_forward_context()
+        # sp_enabled = forward_context.sp_enabled
+        # need_gather_q_kv = False
+        # if sp_enabled and self.debug_layer_idx < self.layers:
+        #     need_gather_q_kv = True
+        # if not sp_enabled or self.debug_layer_idx < self.layers:
+        #     output_shape = hidden_states.shape
+        # else:
+        #     # used in deepseek mtp layer
+        #     output_shape = torch.chunk(hidden_states, self.tp_size,
+        #                                dim=0)[0].shape
+
+        need_gather_q_kv = get_forward_context().sp_enabled
+        output_shape = hidden_states.shape
         # FIXME: This does not seem right, should make sure the buffer is fixed
         output = torch.empty(output_shape,
                              dtype=hidden_states.dtype,
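
Net effect of the change above: the per-layer gating on sp_enabled and debug_layer_idx is dropped, need_gather_q_kv simply mirrors the forward context's sp_enabled flag, and the output buffer is always sized from the full hidden_states shape. A minimal, self-contained sketch of the new prologue, with the forward context stubbed out (_DummyForwardContext, _get_forward_context, and mla_forward_prologue are illustrative names, not the real vLLM API):

from dataclasses import dataclass

import torch


@dataclass
class _DummyForwardContext:
    # Stand-in for the forward context; only the flag used here is modeled.
    sp_enabled: bool = False


def _get_forward_context() -> _DummyForwardContext:
    # Placeholder for vllm.forward_context.get_forward_context().
    return _DummyForwardContext(sp_enabled=True)


def mla_forward_prologue(hidden_states: torch.Tensor):
    # After this commit: gather q/kv whenever sequence parallelism is enabled,
    # and size the output buffer from the full hidden_states shape (the
    # chunked MTP-layer shape is no longer special-cased here).
    need_gather_q_kv = _get_forward_context().sp_enabled
    output_shape = hidden_states.shape
    output = torch.empty(output_shape, dtype=hidden_states.dtype)
    return need_gather_q_kv, output


if __name__ == "__main__":
    gather, out = mla_forward_prologue(torch.randn(4, 8))
    print(gather, out.shape)  # True torch.Size([4, 8])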

vllm_ascend/patch/worker/patch_deepseek_mtp.py

Lines changed: 25 additions & 44 deletions
@@ -52,50 +52,31 @@ def predictor_init(self, vllm_config: VllmConfig, prefix: str) -> None:
                                   topk_indices_buffer)


-def predictor_forward(
-    self,
-    input_ids: torch.Tensor,
-    positions: torch.Tensor,
-    previous_hidden_states: torch.Tensor,
-    inputs_embeds: torch.Tensor,
-    spec_step_index: int = 0,
-) -> torch.Tensor:
-    assert inputs_embeds is not None
-    inputs_embeds = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-        inputs_embeds, True)
-    # masking inputs at position 0, as not needed by MTP
-    inputs_embeds[positions == 0] = 0
-    inputs_embeds = self.enorm(inputs_embeds)
-    previous_hidden_states = self.hnorm(previous_hidden_states)
-
-    hidden_states = self.eh_proj(
-        torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
-
-    hidden_states, residual = self.mtp_block(positions=positions,
-                                             hidden_states=hidden_states,
-                                             residual=None)
-    hidden_states = residual + hidden_states
-    return hidden_states
+# def predictor_forward(
+#     self,
+#     input_ids: torch.Tensor,
+#     positions: torch.Tensor,
+#     previous_hidden_states: torch.Tensor,
+#     inputs_embeds: torch.Tensor,
+#     spec_step_index: int = 0,
+# ) -> torch.Tensor:
+#     assert inputs_embeds is not None
+#     inputs_embeds = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+#         inputs_embeds, True)
+#     # masking inputs at position 0, as not needed by MTP
+#     inputs_embeds[positions == 0] = 0
+#     inputs_embeds = self.enorm(inputs_embeds)
+#     previous_hidden_states = self.hnorm(previous_hidden_states)
+
+#     hidden_states = self.eh_proj(
+#         torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
+
+#     hidden_states, residual = self.mtp_block(positions=positions,
+#                                              hidden_states=hidden_states,
+#                                              residual=None)
+#     hidden_states = residual + hidden_states
+#     return hidden_states


 DeepSeekMultiTokenPredictorLayer.__init__ = predictor_init
-DeepSeekMultiTokenPredictorLayer.forward = predictor_forward
-
-
-def mtp_forward(
-    self,
-    input_ids: torch.Tensor,
-    positions: torch.Tensor,
-    hidden_states: torch.Tensor,
-    intermediate_tensors: IntermediateTensors,
-    inputs_embeds: torch.Tensor,
-    spec_step_idx: int = 0,
-) -> torch.Tensor:
-    hidden_states = self.model(input_ids, positions, hidden_states,
-                               inputs_embeds, spec_step_idx)
-    hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-        hidden_states, True)
-    return hidden_states
-
-
-DeepSeekMTP.forward = mtp_forward
+# DeepSeekMultiTokenPredictorLayer.forward = predictor_forward
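
With the forward patches commented out, DeepSeekMultiTokenPredictorLayer.forward and DeepSeekMTP.forward fall back to the upstream vLLM implementations; only the __init__ override remains. For reference, a toy sketch of the embedding/hidden-state fusion that the removed predictor_forward performed (normalize both streams, concatenate, project back to hidden size); ToyMTPFusion is illustrative only and assumes a torch version that provides nn.RMSNorm:

import torch
from torch import nn


class ToyMTPFusion(nn.Module):
    # Toy version of the enorm/hnorm/eh_proj fusion step: normalize the input
    # embeddings and the previous hidden states, concatenate, project back.
    def __init__(self, hidden_size: int):
        super().__init__()
        self.enorm = nn.RMSNorm(hidden_size)  # requires torch >= 2.4
        self.hnorm = nn.RMSNorm(hidden_size)
        self.eh_proj = nn.Linear(2 * hidden_size, hidden_size, bias=False)

    def forward(self, inputs_embeds, previous_hidden_states, positions):
        inputs_embeds = inputs_embeds.clone()
        inputs_embeds[positions == 0] = 0  # position 0 is not used by MTP
        return self.eh_proj(
            torch.cat([self.enorm(inputs_embeds),
                       self.hnorm(previous_hidden_states)], dim=-1))


if __name__ == "__main__":
    fusion = ToyMTPFusion(16)
    x, h = torch.randn(4, 16), torch.randn(4, 16)
    pos = torch.tensor([0, 1, 2, 3])
    print(fusion(x, h, pos).shape)  # torch.Size([4, 16])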

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 14 additions & 3 deletions
@@ -190,6 +190,8 @@ def dummy_run(self,
                                kv_caches=self.runner.kv_caches[-1:],
                                spec_step_idx=0)
             else:
+                positions = torch.ops.vllm.maybe_pad_and_reduce(positions)
+                previous_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(previous_hidden_states)
                 self.model(input_ids=input_ids,
                            positions=positions,
                            hidden_states=previous_hidden_states)
@@ -474,10 +476,19 @@ def _propose(
                     spec_step_idx=0,
                     **model_kwargs)
             else:
+                input_ids=self.input_ids[:num_input_tokens]
+                positions=self.positions[:num_input_tokens]
+                hidden_states=self.hidden_states[:num_input_tokens]
+
+                positions = torch.ops.vllm.maybe_pad_and_reduce(positions)
+                previous_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(previous_hidden_states)
                 hidden_states = self.model(
-                    input_ids=self.input_ids[:num_input_tokens],
-                    positions=self.positions[:num_input_tokens],
-                    hidden_states=self.hidden_states[:num_input_tokens]
+                    input_ids=input_ids,
+                    positions=positions,
+                    hidden_states=hidden_states
+                )
+                hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                    hidden_states.contiguous(), True
                 )

                 num_indices = last_token_indices.shape[0]
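
The proposer now brackets the drafter call itself: inputs go through torch.ops.vllm.maybe_pad_and_reduce before self.model, and the output goes through torch.ops.vllm.maybe_all_gather_and_maybe_unpad afterwards. A toy, single-process sketch of that pad/shard and gather/unpad round trip on the token dimension (pad_and_split and gather_and_unpad are illustrative stand-ins; the real custom ops also perform the collective communication across the TP group):

import torch


def pad_and_split(x: torch.Tensor, tp_size: int, rank: int) -> torch.Tensor:
    # Pad dim 0 up to a multiple of tp_size, then keep this rank's shard,
    # mimicking what a pad-and-reduce style op does to the token dimension.
    num_tokens = x.shape[0]
    pad = (-num_tokens) % tp_size
    if pad:
        x = torch.nn.functional.pad(x, (0, 0, 0, pad))
    return x.chunk(tp_size, dim=0)[rank]


def gather_and_unpad(shards: list[torch.Tensor], num_tokens: int) -> torch.Tensor:
    # Concatenate all shards (the "all-gather") and drop the padding rows.
    return torch.cat(shards, dim=0)[:num_tokens]


if __name__ == "__main__":
    tp_size, num_tokens = 4, 10
    hidden = torch.randn(num_tokens, 8)
    shards = [pad_and_split(hidden, tp_size, r) for r in range(tp_size)]
    restored = gather_and_unpad(shards, num_tokens)
    assert torch.equal(restored, hidden)
    print(restored.shape)  # torch.Size([10, 8])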
