Commit d1e7e6b
enhance(test-benchmark): use config file for fixed opcode count scenarios (#1790)
* enhance(test-benchmark): use config file for fixed opcode count scenarios
* chore(test-benchmark): update help messages for fixed opcode count and gas bench values
* chore(test-benchmark): fix repricing filter to work with both benchmark options
* chore(test-benchmark): allow fixed-opcode-count for all benchmark tests
* chore(test-benchmark): warn when config file missing for fixed-opcode-count
* chore(test-benchmark): update test to match new help text
* chore(test-benchmark): remove unnecessary generic_visit call in parser
* chore(test-benchmark): format test file
1 parent f64c2a2 commit d1e7e6b

7 files changed: +592 -30 lines changed

.gitignore

Lines changed: 3 additions & 0 deletions

@@ -9,6 +9,9 @@
 # AI
 .claude/

+# Benchmark fixed opcode counts
+.fixed_opcode_counts.json
+
 # C extensions
 *.so
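The newly ignored `.fixed_opcode_counts.json` is generated by the `benchmark_parser` module added later in this commit. Judging from `generate_config_json` below, it contains a single `scenario_configs` object mapping test-name/opcode patterns to lists of opcode counts, with `[1]` as the default for newly detected patterns. A minimal sketch of its shape (the pattern names and counts here are illustrative, not taken from the repository):

{
  "scenario_configs": {
    "test_example.*ADD.*": [1],
    "test_example.*MUL.*": [1]
  }
}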

packages/testing/pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -101,6 +101,7 @@ extract_config = "execution_testing.cli.extract_config:extract_config"
 compare_fixtures = "execution_testing.cli.compare_fixtures:main"
 modify_static_test_gas_limits = "execution_testing.cli.modify_static_test_gas_limits:main"
 validate_changelog = "execution_testing.cli.tox_helpers:validate_changelog"
+benchmark_parser = "execution_testing.cli.benchmark_parser:main"

 [tool.setuptools.packages.find]
 where = ["src"]
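The `benchmark_parser` entry point registered above resolves to the `main` function of the new module below. Roughly, running `uv run benchmark_parser` is equivalent to the following sketch (assuming the package is installed in the active environment):

# Sketch: what the console script executes; mirrors the module's own
# `if __name__ == "__main__"` guard.
from execution_testing.cli.benchmark_parser import main

raise SystemExit(main())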
Lines changed: 364 additions & 0 deletions
@@ -0,0 +1,364 @@

"""
Parser to analyze benchmark tests and maintain the opcode counts mapping.

This script uses Python's AST to analyze benchmark tests and generate/update
the scenario configs in `.fixed_opcode_counts.json`.

Usage:
    uv run benchmark_parser          # Update `.fixed_opcode_counts.json`
    uv run benchmark_parser --check  # Check for new/missing entries (CI)
"""

import argparse
import ast
import json
import sys
from pathlib import Path


def get_repo_root() -> Path:
    """Get the repository root directory."""
    current = Path.cwd()
    while current != current.parent:
        if (current / "tests" / "benchmark").exists():
            return current
        current = current.parent
    raise FileNotFoundError("Could not find repository root")


def get_benchmark_dir() -> Path:
    """Get the benchmark tests directory."""
    return get_repo_root() / "tests" / "benchmark"


def get_config_file() -> Path:
    """Get the .fixed_opcode_counts.json config file path."""
    return get_repo_root() / ".fixed_opcode_counts.json"


class OpcodeExtractor(ast.NodeVisitor):
    """Extract opcode parametrizations from benchmark test functions."""

    def __init__(self, source_code: str):
        self.source_code = source_code
        self.patterns: list[str] = []

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Visit function definitions and extract opcode patterns."""
        if not node.name.startswith("test_"):
            return

        # Check if function has benchmark_test parameter
        if not self._has_benchmark_test_param(node):
            return

        # Filter for code generator usage (required for fixed-opcode-count mode)
        if not self._uses_code_generator(node):
            return

        # Extract opcode parametrizations
        test_name = node.name
        opcodes = self._extract_opcodes(node)

        if opcodes:
            # Test parametrizes on opcodes - create pattern for each
            for opcode in opcodes:
                pattern = f"{test_name}.*{opcode}.*"
                self.patterns.append(pattern)
        else:
            # Test doesn't parametrize on opcodes - use test name only
            pattern = f"{test_name}.*"
            self.patterns.append(pattern)

    def _has_benchmark_test_param(self, node: ast.FunctionDef) -> bool:
        """Check if function has benchmark_test parameter."""
        return any(arg.arg == "benchmark_test" for arg in node.args.args)

    def _uses_code_generator(self, node: ast.FunctionDef) -> bool:
        """Check if function body uses code_generator parameter."""
        func_start = node.lineno - 1
        func_end = node.end_lineno
        if func_end is None:
            return False
        func_source = "\n".join(
            self.source_code.splitlines()[func_start:func_end]
        )
        return "code_generator=" in func_source

    def _extract_opcodes(self, node: ast.FunctionDef) -> list[str]:
        """Extract opcode values from @pytest.mark.parametrize decorators."""
        opcodes: list[str] = []

        for decorator in node.decorator_list:
            if not self._is_parametrize_decorator(decorator):
                continue

            if not isinstance(decorator, ast.Call) or len(decorator.args) < 2:
                continue

            # Get parameter names (first arg)
            param_names = decorator.args[0]
            if isinstance(param_names, ast.Constant):
                param_str = str(param_names.value).lower()
            else:
                continue

            # Check if "opcode" is in parameter names
            if "opcode" not in param_str:
                continue

            # Extract opcode values from second arg (the list)
            param_values = decorator.args[1]
            opcodes.extend(self._parse_opcode_values(param_values))

        return opcodes

    def _is_parametrize_decorator(self, decorator: ast.expr) -> bool:
        """Check if decorator is @pytest.mark.parametrize."""
        if isinstance(decorator, ast.Call):
            if isinstance(decorator.func, ast.Attribute):
                if (
                    isinstance(decorator.func.value, ast.Attribute)
                    and decorator.func.value.attr == "mark"
                    and decorator.func.attr == "parametrize"
                ):
                    return True
        return False

    def _parse_opcode_values(self, values_node: ast.expr) -> list[str]:
        """Parse opcode values from the parametrize list."""
        opcodes: list[str] = []

        if not isinstance(values_node, (ast.List, ast.Tuple)):
            return opcodes

        for element in values_node.elts:
            opcode_name = self._extract_opcode_name(element)
            if opcode_name:
                opcodes.append(opcode_name)

        return opcodes

    def _extract_opcode_name(self, node: ast.expr) -> str | None:
        """
        Extract opcode name from various AST node types.

        Supported patterns (opcode must be first element):

        Case 1 - Direct opcode reference:
            @pytest.mark.parametrize("opcode", [Op.ADD, Op.MUL])
            Result: ["ADD", "MUL"]

        Case 2a - pytest.param with direct opcode:
            @pytest.mark.parametrize("opcode", [pytest.param(Op.ADD, id="add")])
            Result: ["ADD"]

        Case 2b - pytest.param with tuple (opcode first):
            @pytest.mark.parametrize("opcode,arg", [pytest.param((Op.ADD, 123))])
            Result: ["ADD"]

        Case 3 - Plain tuple (opcode first):
            @pytest.mark.parametrize("opcode,arg", [(Op.ADD, 123), (Op.MUL, 456)])
            Result: ["ADD", "MUL"]
        """
        # Case 1: Direct opcode - Op.ADD
        if isinstance(node, ast.Attribute):
            return node.attr

        # Case 2: pytest.param(Op.ADD, ...) or pytest.param((Op.ADD, x), ...)
        if isinstance(node, ast.Call):
            if len(node.args) > 0:
                first_arg = node.args[0]
                # Case 2a: pytest.param(Op.ADD, ...)
                if isinstance(first_arg, ast.Attribute):
                    return first_arg.attr
                # Case 2b: pytest.param((Op.ADD, x), ...)
                elif isinstance(first_arg, ast.Tuple) and first_arg.elts:
                    first_elem = first_arg.elts[0]
                    if isinstance(first_elem, ast.Attribute):
                        return first_elem.attr

        # Case 3: Plain tuple - (Op.ADD, args)
        if isinstance(node, ast.Tuple) and node.elts:
            first_elem = node.elts[0]
            if isinstance(first_elem, ast.Attribute):
                return first_elem.attr

        return None


def scan_benchmark_tests(
    base_path: Path,
) -> tuple[dict[str, list[int]], dict[str, Path]]:
    """
    Scan benchmark test files and extract opcode patterns.

    Returns:
        Tuple of (config, pattern_sources) where:
        - config: mapping of pattern -> opcode counts
        - pattern_sources: mapping of pattern -> source file path
    """
    config: dict[str, list[int]] = {}
    pattern_sources: dict[str, Path] = {}
    default_counts = [1]

    test_files = [
        f
        for f in base_path.rglob("test_*.py")
        if "configs" not in str(f) and "stateful" not in str(f)
    ]

    for test_file in test_files:
        try:
            source = test_file.read_text()
            tree = ast.parse(source)

            extractor = OpcodeExtractor(source)
            extractor.visit(tree)

            for pattern in extractor.patterns:
                if pattern not in config:
                    config[pattern] = default_counts
                    pattern_sources[pattern] = test_file
        except Exception as e:
            print(f"Warning: Failed to parse {test_file}: {e}")
            continue

    return config, pattern_sources


def load_existing_config(config_file: Path) -> dict[str, list[int]]:
    """Load existing config from .fixed_opcode_counts.json."""
    if not config_file.exists():
        return {}

    try:
        data = json.loads(config_file.read_text())
        return data.get("scenario_configs", {})
    except (json.JSONDecodeError, KeyError):
        return {}


def categorize_patterns(
    config: dict[str, list[int]], pattern_sources: dict[str, Path]
) -> dict[str, list[str]]:
    """
    Categorize patterns by deriving category from source file name.

    Example: test_arithmetic.py -> ARITHMETIC
    """
    categories: dict[str, list[str]] = {}

    for pattern in config.keys():
        if pattern in pattern_sources:
            source_file = pattern_sources[pattern]
            file_name = source_file.stem
            if file_name.startswith("test_"):
                category = file_name[5:].upper()  # Remove "test_" prefix
            else:
                category = "OTHER"
        else:
            category = "OTHER"

        if category not in categories:
            categories[category] = []
        categories[category].append(pattern)

    return {k: sorted(v) for k, v in sorted(categories.items())}


def generate_config_json(
    config: dict[str, list[int]],
    pattern_sources: dict[str, Path],
) -> str:
    """Generate the JSON config file content."""
    categories = categorize_patterns(config, pattern_sources)

    scenario_configs: dict[str, list[int]] = {}
    for _, patterns in categories.items():
        for pattern in patterns:
            scenario_configs[pattern] = config[pattern]

    output = {"scenario_configs": scenario_configs}

    return json.dumps(output, indent=2) + "\n"


def main() -> int:
    """Main entry point."""
    parser = argparse.ArgumentParser(
        description="Analyze benchmark tests and maintain opcode count mapping"
    )
    parser.add_argument(
        "--check",
        action="store_true",
        help="Check for new/missing entries (CI mode, exits 1 if out of sync)",
    )
    args = parser.parse_args()

    try:
        benchmark_dir = get_benchmark_dir()
        config_file = get_config_file()
    except FileNotFoundError as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1

    print(f"Scanning benchmark tests in {benchmark_dir}...")
    detected, pattern_sources = scan_benchmark_tests(benchmark_dir)
    print(f"Detected {len(detected)} opcode patterns")

    existing = load_existing_config(config_file)
    print(f"Loaded {len(existing)} existing entries")

    detected_keys = set(detected.keys())
    existing_keys = set(existing.keys())
    new_patterns = sorted(detected_keys - existing_keys)
    obsolete_patterns = sorted(existing_keys - detected_keys)

    merged = detected.copy()
    for pattern, counts in existing.items():
        if pattern in detected_keys:
            merged[pattern] = counts

    print("\n" + "=" * 60)
    print(f"Detected {len(detected)} patterns in tests")
    print(f"Existing entries: {len(existing)}")

    if new_patterns:
        print(f"\n+ Found {len(new_patterns)} NEW patterns:")
        for p in new_patterns[:15]:
            print(f"  {p}")
        if len(new_patterns) > 15:
            print(f"  ... and {len(new_patterns) - 15} more")

    if obsolete_patterns:
        print(f"\n- Found {len(obsolete_patterns)} OBSOLETE patterns:")
        for p in obsolete_patterns[:15]:
            print(f"  {p}")
        if len(obsolete_patterns) > 15:
            print(f"  ... and {len(obsolete_patterns) - 15} more")

    if not new_patterns and not obsolete_patterns:
        print("\nConfiguration is up to date!")

    print("=" * 60)

    if args.check:
        if new_patterns or obsolete_patterns:
            print("\nRun 'uv run benchmark_parser' (without --check) to sync.")
            return 1
        return 0

    for pattern in obsolete_patterns:
        print(f"Removing obsolete: {pattern}")
        if pattern in merged:
            del merged[pattern]

    content = generate_config_json(merged, pattern_sources)
    config_file.write_text(content)
    print(f"\nUpdated {config_file}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
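To illustrate how `OpcodeExtractor` turns a parametrized benchmark test into config patterns, here is a small sketch that feeds it a toy test source. The test body, opcode names, and `make_generator` helper are made up for the example, and the import assumes the `execution_testing` package is installed:

import ast

from execution_testing.cli.benchmark_parser import OpcodeExtractor

# Hypothetical benchmark test source, used only for this illustration.
toy_source = '''
import pytest


@pytest.mark.parametrize("opcode", [Op.ADD, Op.MUL])
def test_toy_binop(benchmark_test, opcode):
    benchmark_test(code_generator=make_generator(opcode))
'''

extractor = OpcodeExtractor(toy_source)
extractor.visit(ast.parse(toy_source))

# Expected: ['test_toy_binop.*ADD.*', 'test_toy_binop.*MUL.*']
print(extractor.patterns)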
