From 0dc8cacfe9ce1ae158cf861f1ac9b333720c1649 Mon Sep 17 00:00:00 2001
From: slokesha
Date: Mon, 27 Oct 2025 20:55:51 -0700
Subject: [PATCH 1/2] Workaround for Assertion error when embedding with bge-m3 in lazy mode

Signed-off-by: slokesha
---
 vllm/model_executor/models/roberta.py | 44 ---------------------------
 vllm/worker/hpu_model_runner.py       | 13 +++++++-
 2 files changed, 12 insertions(+), 45 deletions(-)

diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py
index ee88f5fbe819..dd4c15126cfc 100644
--- a/vllm/model_executor/models/roberta.py
+++ b/vllm/model_executor/models/roberta.py
@@ -124,19 +124,6 @@ def forward_hpu(
             pos_list.append(position_ids[offset])
             token_list.append(input_ids[offset])
 
-        for index, (positions, tokens, seq_len) in enumerate(
-                zip(pos_list, token_list, seq_lens)):
-            # Verify assumption that incoming position are
-            # always a sequence from 0 to N.
-            expected_pos = torch.arange(positions.size()[0],
-                                        dtype=torch.long,
-                                        device=inputs_embeds.device)
-            valid_input_mask = expected_pos < seq_len
-            expected_pos = expected_pos * valid_input_mask
-            assert torch.equal(positions, expected_pos)
-            position_ids[index] = create_position_ids_from_input_ids_hpu(
-                tokens, self.padding_idx, seq_len)
-
         # Position embeddings.
         position_embeddings = self.position_embeddings(position_ids)
         if token_type_ids is None:
@@ -207,37 +194,6 @@ def forward_cuda(
         return self.forward_native(input_ids, seq_lens, position_ids,
                                    token_type_ids)
 
-
-# Adapted from transformers
-def create_position_ids_from_input_ids_hpu(input_ids,
-                                           padding_idx,
-                                           seq_len,
-                                           past_key_values_length=0):
-    """
-    Replace non-padding symbols with their position numbers.
-    Position numbers begin at padding_idx+1. Padding symbols
-    are ignored. This is modified from fairseq's `utils.make_positions`.
-
-    Args:
-        x: torch.Tensor x:
-
-    Returns: torch.Tensor
-    """
-    # The series of casts and type-conversions here are carefully
-    # balanced to both work with ONNX export and XLA.
-    valid_input_mask = torch.arange(input_ids.size()[0],
-                                    dtype=torch.int,
-                                    device=input_ids.device)
-    valid_input_mask = valid_input_mask < seq_len
-
-    mask = input_ids.ne(padding_idx).int()
-
-    incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) +
-                           past_key_values_length) * mask
-
-    return (incremental_indices.long() + padding_idx) * valid_input_mask
-
-
 # Adapted from transformers
 class RobertaClassificationHead(nn.Module):
     """Head for sentence-level classification tasks."""
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index e01b79f00a5f..9f482ed0ae44 100644
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -1579,7 +1579,18 @@ def _prepare_prompt(
             input_tokens.append(prompt_tokens)
             # NOTE(woosuk): Here we assume that the first token in the prompt
             # is always the first token in the sequence.
-            input_positions.append(list(range(context_len, seq_len)))
+            if "RobertaEmbeddingModel" in str(type(self.model.model)):
+                padding_idx = getattr(self.model.model.model.embeddings, "padding_idx", 1)
+                tokens_cpu = torch.tensor(prompt_tokens, dtype=torch.long, device="cpu").clone().contiguous()
+                mask = tokens_cpu.ne(padding_idx).to(torch.int32)
+                incremental_indices = (torch.cumsum(mask, dim=0).to(torch.int32) * mask)
+                pos_cpu = incremental_indices.to(torch.int64) + padding_idx
+                if seq_len < pos_cpu.numel():
+                    pos_cpu[seq_len:] = 0
+                pos_hpu = pos_cpu.to("hpu", non_blocking=False)
+                input_positions.append(pos_hpu.tolist())
+            else:
+                input_positions.append(list(range(context_len, seq_len)))
 
             seq_data_mrope_positions: Optional[List[List[int]]] = None
 

From f34ca0da5dc8c61443cabacecb46c8207d4a051c Mon Sep 17 00:00:00 2001
From: slokesha
Date: Mon, 27 Oct 2025 21:23:05 -0700
Subject: [PATCH 2/2] Precommit Fix

Signed-off-by: slokesha
---
 vllm/worker/hpu_model_runner.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 9f482ed0ae44..059ba4dae8c5 100644
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -1580,10 +1580,14 @@ def _prepare_prompt(
             # NOTE(woosuk): Here we assume that the first token in the prompt
             # is always the first token in the sequence.
             if "RobertaEmbeddingModel" in str(type(self.model.model)):
-                padding_idx = getattr(self.model.model.model.embeddings, "padding_idx", 1)
-                tokens_cpu = torch.tensor(prompt_tokens, dtype=torch.long, device="cpu").clone().contiguous()
+                padding_idx = getattr(self.model.model.model.embeddings,
+                                      "padding_idx", 1)
+                tokens_cpu = torch.tensor(prompt_tokens,
+                                          dtype=torch.long,
+                                          device="cpu").clone().contiguous()
                 mask = tokens_cpu.ne(padding_idx).to(torch.int32)
-                incremental_indices = (torch.cumsum(mask, dim=0).to(torch.int32) * mask)
+                incremental_indices = (
+                    torch.cumsum(mask, dim=0).to(torch.int32) * mask)
                 pos_cpu = incremental_indices.to(torch.int64) + padding_idx
                 if seq_len < pos_cpu.numel():
                     pos_cpu[seq_len:] = 0
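Note on the workaround: the position ids that the patch now builds on the host follow the RoBERTa convention. Non-padding tokens are numbered padding_idx + 1, padding_idx + 2, ... by taking a cumulative sum over a non-padding mask, and any slot past the valid sequence length is zeroed so that bucketed padding on HPU does not pick up real position embeddings; this appears to replace the data-dependent assert removed from forward_hpu. The standalone sketch below illustrates only that computation; the helper name roberta_position_ids and the example pad id of 1 are illustrative assumptions and are not part of the patch.

import torch


def roberta_position_ids(prompt_tokens, seq_len, padding_idx=1):
    """Sketch of the padding-aware position ids built in _prepare_prompt.

    Assumes padding_idx=1 (RoBERTa's usual pad token id); the patch itself
    reads it from the model embeddings via getattr(..., "padding_idx", 1).
    """
    tokens = torch.tensor(prompt_tokens, dtype=torch.long)
    # 1 for real tokens, 0 for padding tokens.
    mask = tokens.ne(padding_idx).to(torch.int32)
    # Non-padding tokens get 1, 2, 3, ...; padding slots stay at 0.
    incremental_indices = torch.cumsum(mask, dim=0).to(torch.int32) * mask
    # Shift so the first real position is padding_idx + 1, as RoBERTa expects.
    pos = incremental_indices.to(torch.int64) + padding_idx
    # Zero everything past the valid length (e.g. bucketed padding on HPU).
    if seq_len < pos.numel():
        pos[seq_len:] = 0
    return pos.tolist()


# Example: a 4-token prompt padded to a bucket of 8 with pad id 1.
print(roberta_position_ids([0, 31414, 232, 2, 1, 1, 1, 1], seq_len=4))
# -> [2, 3, 4, 5, 0, 0, 0, 0]

In the patched _prepare_prompt the same result is moved to the HPU and appended to input_positions, so the model no longer has to recompute or verify the positions inside forward_hpu.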