@@ -52,50 +52,31 @@ def predictor_init(self, vllm_config: VllmConfig, prefix: str) -> None:
         topk_indices_buffer)


-def predictor_forward(
-    self,
-    input_ids: torch.Tensor,
-    positions: torch.Tensor,
-    previous_hidden_states: torch.Tensor,
-    inputs_embeds: torch.Tensor,
-    spec_step_index: int = 0,
-) -> torch.Tensor:
-    assert inputs_embeds is not None
-    inputs_embeds = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-        inputs_embeds, True)
-    # masking inputs at position 0, as not needed by MTP
-    inputs_embeds[positions == 0] = 0
-    inputs_embeds = self.enorm(inputs_embeds)
-    previous_hidden_states = self.hnorm(previous_hidden_states)
-
-    hidden_states = self.eh_proj(
-        torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
-
-    hidden_states, residual = self.mtp_block(positions=positions,
-                                             hidden_states=hidden_states,
-                                             residual=None)
-    hidden_states = residual + hidden_states
-    return hidden_states
+# def predictor_forward(
+#     self,
+#     input_ids: torch.Tensor,
+#     positions: torch.Tensor,
+#     previous_hidden_states: torch.Tensor,
+#     inputs_embeds: torch.Tensor,
+#     spec_step_index: int = 0,
+# ) -> torch.Tensor:
+#     assert inputs_embeds is not None
+#     inputs_embeds = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+#         inputs_embeds, True)
+#     # masking inputs at position 0, as not needed by MTP
+#     inputs_embeds[positions == 0] = 0
+#     inputs_embeds = self.enorm(inputs_embeds)
+#     previous_hidden_states = self.hnorm(previous_hidden_states)
+
+#     hidden_states = self.eh_proj(
+#         torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
+
+#     hidden_states, residual = self.mtp_block(positions=positions,
+#                                              hidden_states=hidden_states,
+#                                              residual=None)
+#     hidden_states = residual + hidden_states
+#     return hidden_states


 DeepSeekMultiTokenPredictorLayer.__init__ = predictor_init
-DeepSeekMultiTokenPredictorLayer.forward = predictor_forward
-
-
-def mtp_forward(
-    self,
-    input_ids: torch.Tensor,
-    positions: torch.Tensor,
-    hidden_states: torch.Tensor,
-    intermediate_tensors: IntermediateTensors,
-    inputs_embeds: torch.Tensor,
-    spec_step_idx: int = 0,
-) -> torch.Tensor:
-    hidden_states = self.model(input_ids, positions, hidden_states,
-                               inputs_embeds, spec_step_idx)
-    hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-        hidden_states, True)
-    return hidden_states
-
-
-DeepSeekMTP.forward = mtp_forward
+# DeepSeekMultiTokenPredictorLayer.forward = predictor_forward
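For context on the pattern this hunk touches: the file monkey-patches vLLM's DeepSeek MTP classes by assigning module-level functions over their methods, and after this change only the __init__ patch remains bound while the forward patches are commented out, so the classes fall back to their own forward implementations. Below is a minimal, hedged sketch of that rebinding pattern, not this repository's code; the import path is an assumption (it may differ across vLLM versions) and the function body is a stub.

# Sketch only: illustrates rebinding a method on a vLLM class.
# Assumption: DeepSeekMultiTokenPredictorLayer is importable from
# vllm.model_executor.models.deepseek_mtp; adjust to the vLLM version in use.
from vllm.model_executor.models.deepseek_mtp import (
    DeepSeekMultiTokenPredictorLayer)


def predictor_init(self, vllm_config, prefix):
    # Custom construction logic would go here (see the real predictor_init above).
    ...


# Assigning to the class attribute replaces the method for every instance created
# afterwards; per this commit, only __init__ is still patched this way.
DeepSeekMultiTokenPredictorLayer.__init__ = predictor_init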