Commit 224832a

committed
Working E2E pybind runner
1 parent a681d18

File tree

6 files changed: +45 −36 lines changed


optimum/executorch/modeling.py

Lines changed: 19 additions & 17 deletions
@@ -772,7 +772,7 @@ def text_generation(
             raise ValueError(
                 f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}."
             )
-        if not verify_eos_tokens_in_tokenizer(self.eos_token_ids, self.tokenizer):
+        if not verify_eos_tokens_in_pretrained_tokenizer(self.eos_token_ids, self.tokenizer):
             raise ValueError(
                 f"The tokenizer's eos_token_id does not match with the model's eos_token_ids={self.eos_token_ids}."
             )
@@ -1300,25 +1300,28 @@ def generate(
         )
         max_seq_len = self.max_cache_size

+        # Prefill.
         self.stats.on_sampling_begin()
         logits = self.forward(
-            input_ids=torch.tensor(prompt_tokens, dtype=torch.long, device=self.device).unsqueeze(0),
-            cache_position=torch.arange(len(prompt_tokens), dtype=torch.long, device=self.device),
+            input_ids=torch.tensor(prompt_tokens, dtype=torch.long, device=self.device),
+            cache_position=torch.arange(len(prompt_tokens[0]), dtype=torch.long, device=self.device),
             input_features=input_features,
         )
         self.stats.on_sampling_end()
-        next_token = torch.argmax(logits, dim=-1)[0, -1].item()
         self.stats.on_prompt_eval_end()
-        first_token_generated = False

-        generated_tokens = prompt_tokens + [next_token]
+        next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
+        generated_tokens = [next_token]
+        print(self.tokenizer.decode([next_token]), end="")

-        while len(generated_tokens) < max_seq_len:
+        # Token-by-token generation.
+        first_token_generated = False
+        while len(generated_tokens) + len(prompt_tokens) < max_seq_len:
             self.stats.on_sampling_begin()
             logits = self.forward(
                 input_ids=torch.tensor([next_token], dtype=torch.long, device=self.device).unsqueeze(0),
                 cache_position=torch.tensor(
-                    [pos_base + len(generated_tokens) - 1],
+                    [pos_base + len(generated_tokens) + len(prompt_tokens) - 1],
                     dtype=torch.long,
                     device=self.device,
                 ),
@@ -1328,20 +1331,20 @@ def generate(
                 self.stats.on_first_token()
                 first_token_generated = True

-            next_token = torch.argmax(logits, dim=-1).item()
+            next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
             generated_tokens.append(next_token)
+            print(self.tokenizer.decode([next_token]), end="")

-            if next_token in self.eos_token_ids:
+            if next_token == self.eos_token_id:
                 break

         self.stats.set_num_generated_tokens(len(generated_tokens) - len(prompt_tokens))
-
         return generated_tokens if echo else generated_tokens[len(prompt_tokens) :]

     def text_generation(
         self,
         processor: "ProcessorMixin",
-        tokenizer: "PreTrainedTokenizer",
+        tokenizer: PreTrainedTokenizer,
         input_conversation: List[Dict],
         echo: bool = True,
         max_seq_len: Optional[int] = None,
@@ -1368,22 +1371,21 @@ def text_generation(
             raise ValueError(
                 f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}."
             )
-        if not verify_eos_tokens_in_tokenizer(self.eos_token_ids, self.tokenizer):
+        if isinstance(self.tokenizer, PreTrainedTokenizer) and verify_eos_tokens_in_pretrained_tokenizer(self.eos_token_id, self.tokenizer):
             raise ValueError(
-                f"The tokenizer's eos_token_id does not match with the model's eos_token_ids={self.eos_token_ids}."
+                f"The tokenizer's eos_token_id does not match with the model's eos_token_id={self.eos_token_id}."
             )

         # Reset stats for a new generation
         self.stats.reset()
         self.stats.on_inference_start()

         inputs = processor.apply_chat_template(input_conversation)
-        prompt_tokens = self.tokenizer.encode(inputs["input_ids"])
         self.stats.on_token_encode_end()
-        self.stats.set_num_prompt_tokens(len(prompt_tokens))
+        self.stats.set_num_prompt_tokens(len(inputs["input_ids"][0]))

         generated_tokens = self.generate(
-            prompt_tokens=prompt_tokens,
+            prompt_tokens=inputs["input_ids"],
             input_features=inputs["input_features"],
             echo=echo,
             max_seq_len=max_seq_len,
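For orientation, here is a minimal, self-contained sketch of the prefill-then-decode flow this change moves toward. It is not the repository's implementation: the forward() keyword arguments, the (1, prompt_len) shape of prompt_tokens (as returned by processor.apply_chat_template), and the omission of the stats bookkeeping are assumptions made purely for illustration.

import torch

def greedy_generate(model, prompt_tokens, input_features, max_seq_len, eos_token_id):
    # prompt_tokens: assumed to be a (1, prompt_len) long tensor.
    prompt_len = prompt_tokens.shape[1]

    # Prefill: a single forward pass over the whole prompt.
    logits = model.forward(
        input_ids=prompt_tokens,
        cache_position=torch.arange(prompt_len, dtype=torch.long),
        input_features=input_features,
    )
    next_token = torch.argmax(logits[:, -1, :], dim=-1).item()  # logits of the last prompt position
    generated_tokens = [next_token]

    # Decode: one token per step; cache_position keeps advancing past the prompt.
    while len(generated_tokens) + prompt_len < max_seq_len:
        logits = model.forward(
            input_ids=torch.tensor([[next_token]], dtype=torch.long),
            cache_position=torch.tensor([prompt_len + len(generated_tokens) - 1], dtype=torch.long),
        )
        next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
        generated_tokens.append(next_token)
        if next_token == eos_token_id:
            break
    return generated_tokens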

optimum/exporters/executorch/integrations.py

Lines changed: 4 additions & 2 deletions
@@ -671,17 +671,19 @@ def export(
         exported_programs["token_embeddings"] = token_embeddings_exported_program

         # 3. Export encoder.
+        input_ids = torch.zeros_like(inputs_embeds[:, :, 0], dtype=torch.long)
+        input_ids[0, 1] = self.config.audio_token_id  # Make sure we don't have an all-false mask for the inputs_embeds.
         if isinstance(self.model, VoxtralForConditionalGeneration):
             # TODO(JZ): specific to Voxtral, should generalize.
             chunk_length = self.model.audio_tower.config.max_source_positions * self.model.audio_tower.conv1.stride[0] * self.model.audio_tower.conv2.stride[0]
             encoder_input_kwargs = {
                 "input_features": torch.rand(3, 128, chunk_length),  # (bsz, features, seq_len)
                 "inputs_embeds": inputs_embeds,
-                "input_ids": inputs_embeds[:, :, 0],
+                "input_ids": input_ids,
             }

         max_audio_len = 150  # In s, should be a multiple of 30. TODO(JZ): make this configurable top-level.
-        max_seq_len = self.metadata.get("get_max_seq_len") - 1  # TODO(JZ): why - 1? Copied from Gemma3 draft PR.
+        max_seq_len = self.metadata.get("get_max_seq_len")
         dynamic_shapes = {
             "input_features": {
                 0: torch.export.Dim("enc_batch_size_dim", min=1, max=max_audio_len//30),
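The dynamic_shapes entry above bounds the encoder batch dimension by the number of 30-second audio chunks. Below is a minimal sketch of that torch.export pattern using a toy module instead of the Voxtral audio tower; the module, tensor sizes, and names are illustrative assumptions, not the exporter's code.

import torch

class ToyEncoder(torch.nn.Module):
    # Stand-in for the audio encoder; the real model is omitted for brevity.
    def forward(self, input_features):
        return input_features.mean(dim=-1)

max_audio_len = 150  # seconds; one chunk per 30 s of audio
enc_batch_dim = torch.export.Dim("enc_batch_size_dim", min=1, max=max_audio_len // 30)

example_features = torch.rand(3, 128, 3000)  # (num_chunks, mel_bins, frames) -- sizes assumed
exported = torch.export.export(
    ToyEncoder(),
    (example_features,),
    dynamic_shapes={"input_features": {0: enc_batch_dim}},
)
print(exported.graph_signature)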

optimum/exporters/executorch/recipes/xnnpack.py

Lines changed: 10 additions & 8 deletions
@@ -89,14 +89,16 @@ def _lower_to_executorch(
         et_prog = et_prog.to_executorch(
             config=ExecutorchBackendConfig(**backend_config_dict),
         )
-        logging.debug(
-            f"\nExecuTorch program for {pte_name}.pte: {et_prog.exported_program().graph_module}"
-        )
-        delegation_info = get_delegation_info(et_prog.exported_program().graph_module)
-        logging.debug(f"\nDelegation info Summary for {pte_name}.pte: {delegation_info.get_summary()}")
-        logging.debug(
-            f"\nDelegation info for {pte_name}.pte: {tabulate(delegation_info.get_operator_delegation_dataframe(), headers='keys', tablefmt='fancy_grid')}"
-        )
+        for method in et_prog.methods:
+            logging.debug(f"---------------------- Method: {method} ----------------------")
+            logging.debug(
+                f"\nExecuTorch program for {pte_name}.pte: {et_prog.exported_program(method).graph_module}"
+            )
+            delegation_info = get_delegation_info(et_prog.exported_program(method).graph_module)
+            logging.debug(f"\nDelegation info Summary for {pte_name}.pte: {delegation_info.get_summary()}")
+            logging.debug(
+                f"\nDelegation info for {pte_name}.pte: {tabulate(delegation_info.get_operator_delegation_dataframe(), headers='keys', tablefmt='fancy_grid')}"
+            )
         return {pte_name: et_prog}

     exported_progs = model.export()
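Because one .pte can now contain several methods (for example token embeddings, audio encoder, and text decoder), delegation has to be inspected per method rather than once per program. A hedged sketch of that pattern follows; the import path reflects recent ExecuTorch releases and may differ, and the helper name summarize_delegation is an assumption, not part of this recipe.

from executorch.devtools.backend_debug import get_delegation_info  # path assumed; the recipe already imports this helper

def summarize_delegation(et_prog, pte_name):
    # et_prog: the program manager returned by to_executorch().
    for method in et_prog.methods:
        graph_module = et_prog.exported_program(method).graph_module
        info = get_delegation_info(graph_module)
        print(f"{pte_name}.pte [{method}]: {info.get_summary()}")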

optimum/exporters/executorch/tasks/multimodal_text_to_text.py

Lines changed: 2 additions & 2 deletions
@@ -61,6 +61,8 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
     attn_implementation = kwargs.get("attn_implementation", "custom_sdpa" if use_custom_sdpa else "sdpa")
     cache_implementation = kwargs.get("cache_implementation", "static")
     use_custom_sdpa = use_custom_sdpa or attn_implementation == "custom_sdpa"
+    qlinear_config = kwargs.get("qlinear", None)
+    qembedding_config = kwargs.get("qembedding", None)
     max_length = kwargs.get("max_length", 2048)
     config = kwargs.get("config") or AutoConfig.from_pretrained(model_name_or_path)

@@ -111,8 +113,6 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):

     # TODO: Move quantization recipe out for better composability.
     # TODO: Should switch to `TorchAoConfig` once the quant issue on final lm_head layer is fixed.
-    qlinear_config = kwargs.get("qlinear", None)
-    qembedding_config = kwargs.get("qembedding", None)
     if qlinear_config or qembedding_config:
         # TODO: Update torchao to use 0.11.0 once released
         if parse(torchao.__version__) < parse("0.11.0.dev0"):

optimum/exporters/executorch/utils.py

Lines changed: 2 additions & 1 deletion
@@ -16,6 +16,7 @@

 import torch
 from transformers import GenerationConfig, PretrainedConfig
+from transformers.tokenization_utils import PreTrainedTokenizer


 def save_config_to_constant_methods(
@@ -65,7 +66,7 @@ def save_config_to_constant_methods(
     return {k: v for k, v in {**metadata, **kwargs}.items() if v is not None}


-def verify_eos_tokens_in_tokenizer(model_eos_ids: List[int], tokenizer) -> bool:
+def verify_eos_tokens_in_pretrained_tokenizer(model_eos_ids: List[int], tokenizer: PreTrainedTokenizer) -> bool:
     """
     Verifies that the model's EOS token IDs are present in the tokenizer's
     set of potential end-of-sequence tokens.
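The renamed helper narrows its contract to transformers' PreTrainedTokenizer. Its body is not shown in this diff; the following is only a plausible sketch of what a check like this typically does, and the way candidate EOS ids are collected here is an assumption rather than the repository's logic.

from typing import List

from transformers.tokenization_utils import PreTrainedTokenizer


def verify_eos_tokens_in_pretrained_tokenizer(model_eos_ids: List[int], tokenizer: PreTrainedTokenizer) -> bool:
    """Return True if every model EOS id is among the tokenizer's plausible EOS ids."""
    candidates = set()
    if tokenizer.eos_token_id is not None:
        candidates.add(tokenizer.eos_token_id)
    # Added/special tokens often include alternative end-of-turn markers.
    candidates.update(tokenizer.added_tokens_decoder.keys())
    return all(eos_id in candidates for eos_id in model_eos_ids)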

tests/models/test_modeling_voxtral.py

Lines changed: 8 additions & 6 deletions
@@ -27,7 +27,7 @@
 import transformers
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
 from packaging.version import parse
-from transformers import AutoTokenizer, AutoProcessor
+from transformers import AutoConfig, AutoTokenizer, AutoProcessor
 from transformers.testing_utils import slow

 from optimum.utils.import_utils import is_transformers_version
@@ -42,7 +42,7 @@

 os.environ["TOKENIZERS_PARALLELISM"] = "false"

-logging.basicConfig(level=logging.INFO)
+logging.basicConfig(level=logging.DEBUG)


 @pytest.mark.skipif(
@@ -71,15 +71,16 @@ def __init__(self, *args, **kwargs):
     # reason="Only available on transformers >= 4.53.0.dev0 and torchao >= 0.11.0",
     # )
     # @pytest.mark.skipif(is_linux_ci, reason="OOM on linux runner")
-    @pytest.mark.skip()
+    # @pytest.mark.skip()
     def test_voxtral_audio_text_to_text_generation_with_custom_sdpa_kv_cache_8da4w_8we_exported_program(self):
         model_id = "mistralai/Voxtral-Mini-3B-2507"
+        config = AutoConfig.from_pretrained(model_id)
         module = load_multimodal_text_to_text_model(
             model_id,
             use_custom_sdpa=True,
             use_custom_kv_cache=True,
             qlinear=True,
-            qembedding_config=True,
+            qembedding=True,
         )

         res = module.export()
@@ -166,11 +167,12 @@ def test_voxtral_audio_text_to_text_generation_with_custom_sdpa_kv_cache_8da4w_8
         ]

         model = ExecuTorchModelForMultiModalToText.from_pretrained(
-            model_id,
+            # model_id,
+            "/Users/jackzhxng/Documents/voxtral",  # Load already exported model in local file path.
             recipe="xnnpack",
             attn_implementation="custom_sdpa",
             use_custom_kv_cache=True,
-            **{"qlinear": True, "qembeeding": True, "task": "multimodal-text-to-text"},
+            **{"qlinear": True, "qembedding": True, "task": "multimodal-text-to-text"},
         )
         self.assertIsInstance(model, ExecuTorchModelForMultiModalToText)
         self.assertIsInstance(model.model, ExecuTorchModule)
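For context, here is a hedged end-to-end sketch of what the re-enabled test exercises with the pybind runner. The package-level import, the conversation structure, and the audio path are assumptions for illustration; the test itself loads an already exported model from a local directory instead of exporting from the hub id.

from transformers import AutoProcessor, AutoTokenizer

from optimum.executorch import ExecuTorchModelForMultiModalToText  # import path assumed

model_id = "mistralai/Voxtral-Mini-3B-2507"
processor = AutoProcessor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = ExecuTorchModelForMultiModalToText.from_pretrained(
    model_id,  # or a local directory containing the already exported program
    recipe="xnnpack",
    attn_implementation="custom_sdpa",
    use_custom_kv_cache=True,
    **{"qlinear": True, "qembedding": True, "task": "multimodal-text-to-text"},
)

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "path": "sample.wav"},  # illustrative local audio file
            {"type": "text", "text": "What is said in this audio?"},
        ],
    }
]
generated = model.text_generation(
    processor=processor,
    tokenizer=tokenizer,
    input_conversation=conversation,
    max_seq_len=128,
)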
