Skip to content

Commit 3643fee

Browse files
committed
fix grok-1 tests
Signed-off-by: Mamta Singh <[email protected]>
1 parent 6ba051c commit 3643fee

File tree

3 files changed

+38
-6
lines changed

3 files changed

+38
-6
lines changed

QEfficient/utils/test_utils.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,34 @@ class ModelConfig:
165165
}
166166

167167
EXTERNAL_MODELS = {
168-
"hpcai-tech/grok-1",
168+
"hpcai-tech/grok-1": {
169+
"pytorch_hf_tokens": [
170+
391,
171+
391,
172+
391,
173+
391,
174+
391,
175+
391,
176+
391,
177+
391,
178+
391,
179+
391,
180+
391,
181+
391,
182+
391,
183+
391,
184+
391,
185+
391,
186+
391,
187+
391,
188+
391,
189+
391,
190+
391,
191+
391,
192+
391,
193+
391,
194+
]
195+
}
169196
}
170197

171198
SWIFTKV_MODELS = {

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@ dependencies = [
3939
"fire",
4040
"py7zr",
4141
"torchmetrics==1.7.0",
42-
"torch==2.7.1; platform_machine=='aarch64'",
42+
"torch==2.7.0; platform_machine=='aarch64'",
4343
# Specifying torch cpu package URL per python version, update the list once pytorch releases whl for python>3.11
4444
"torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp38-cp38-linux_x86_64.whl ; python_version=='3.8' and platform_machine=='x86_64'",
45-
"torch@https://download.pytorch.org/whl/cpu/torch-2.7.1%2Bcpu-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_machine=='x86_64'",
46-
"torch@https://download.pytorch.org/whl/cpu/torch-2.7.1%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_machine=='x86_64'",
45+
"torch@https://download.pytorch.org/whl/cpu/torch-2.7.0%2Bcpu-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_machine=='x86_64'",
46+
"torch@https://download.pytorch.org/whl/cpu/torch-2.7.0%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_machine=='x86_64'",
4747
]
4848

4949
[project.optional-dependencies]

tests/transformers/models/test_causal_lm_models.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,10 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
170170
Constants.CTX_LEN,
171171
)
172172

173-
if model_name not in ModelConfig.SWIFTKV_MODELS:
173+
if model_name not in (ModelConfig.SWIFTKV_MODELS and ModelConfig.EXTERNAL_MODELS):
174174
pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf)
175+
if model_name in ModelConfig.EXTERNAL_MODELS:
176+
pytorch_hf_tokens = ModelConfig.EXTERNAL_MODELS[model_name]["pytorch_hf_tokens"]
175177

176178
is_tlm = False if num_speculative_tokens is None else True
177179
qeff_model = QEFFAutoModelForCausalLM(
@@ -232,10 +234,13 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
232234
full_batch_size,
233235
)
234236

235-
if model_name not in ModelConfig.SWIFTKV_MODELS:
237+
if model_name not in (ModelConfig.SWIFTKV_MODELS and ModelConfig.EXTERNAL_MODELS):
236238
pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch_CB(model_hf)
237239
pytorch_hf_tokens = np.vstack(pytorch_hf_tokens)
238240

241+
if model_name in ModelConfig.EXTERNAL_MODELS:
242+
pytorch_hf_tokens = [pytorch_hf_tokens for _ in range(full_batch_size)]
243+
239244
qeff_model = QEFFAutoModelForCausalLM(
240245
model_hf, continuous_batching=True, is_tlm=is_tlm, pretrained_model_name_or_path=model_name
241246
)

0 commit comments

Comments
 (0)