Skip to content

Commit 70cb5b6

Browse files
committed
[mlir] Use MLIR op names when generating FileCheck variables in generate-test-checks.py
Motivation ---------- Improve readability and stability of autogenerated CHECK lines by using operation-aware FileCheck variable names instead of generic VAL_N. What changes ------------ - When possible, variable names are derived from the MLIR op name, e.g. `vector.transfer_read` → `TRANSFER_READ_0`. - Unknown ops (e.g., from out-of-tree dialects) fall back to the prior `VAL_N` scheme. Before ------ ```mlir // CHECK: %[[VAL_4:.*]] = vector.transfer_read ... // CHECK: %[[VAL_5:.*]] = "val_use"(%[[VAL_4]]) : ... ``` After ----- ```mlir // CHECK: %[[TRANSFER_READ_0:.*]] = vector.transfer_read ... // CHECK: %[[VAL_1:.*]] = "val_use"(%[[TRANSFER_READ_0]]) : ... ``` Rationale --------- Using op-derived names (e.g., `TRANSFER_READ_0`) makes tests easier to read and audit, while remaining more stable across unrelated edits (e.g. there will always be fewer `TRANSFER_READ_#N` variables than `VAL_#N`). The fallback to `VAL_N` preserves compatibility for unknown ops.
1 parent 0d989b2 commit 70cb5b6

File tree

1 file changed

+37
-6
lines changed

1 file changed

+37
-6
lines changed

mlir/utils/generate-test-checks.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import os # Used to advertise this file's name ("autogenerated_note").
3232
import re
3333
import sys
34+
from collections import Counter
3435

3536
ADVERT_BEGIN = "// NOTE: Assertions have been autogenerated by "
3637
ADVERT_END = """
@@ -45,6 +46,14 @@
4546
SSA_RE_STR = "[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*"
4647
SSA_RE = re.compile(SSA_RE_STR)
4748

49+
# Regex matching `dialect.op_name`, where `dialect` is an upstream MLIR
50+
# dialect (e.g. `vector.transfer_read`).
51+
DIALECTS = "acc|affine|amdgpu|amx|arith|arm_neon|arm_sve|arm_sme|async|bufferization|cf|complex|dlti|emitc|\
52+
func|gpu|index|irdl|linalg|llvm|math|memref|ml_program|mpi|nvgpu|nvvm|omp|pdl_interp|pdl|ptr|quant|\
53+
rocdl|scf|shape|shard|smt|sparse_tensor|tensor|ub|vcix|vector|wasmssa|x86vector|xegpu|xevm|spirv|tosa|\
54+
transform"
55+
SSA_OP_NAME_RE = re.compile(rf"\b(?:{DIALECTS})[.]([a-z_]+)\b")
56+
4857
# Regex matching the left-hand side of an assignment
4958
SSA_RESULTS_STR = r'\s*(%' + SSA_RE_STR + r')(\s*,\s*(%' + SSA_RE_STR + r'))*\s*='
5059
SSA_RESULTS_RE = re.compile(SSA_RESULTS_STR)
@@ -63,7 +72,12 @@
6372
class VariableNamer:
6473
def __init__(self, variable_names):
6574
self.scopes = []
75+
# Counter for generic FileCHeck names, e.g. VAL_#N
6676
self.name_counter = 0
77+
# Counters for FileCheck names derived from Op names, e.g.
78+
# TRANSFER_READ_#N (based on `vector.transfer_read`). Note, there's a
79+
# dedicated counter for every Op type present in the input.
80+
self.op_name_counter = Counter()
6781

6882
# Number of variable names to still generate in parent scope
6983
self.generate_in_parent_scope_left = 0
@@ -77,17 +91,29 @@ def generate_in_parent_scope(self, n):
7791
self.generate_in_parent_scope_left = n
7892

7993
# Generate a substitution name for the given ssa value name.
80-
def generate_name(self, source_variable_name, use_ssa_name):
94+
def generate_name(self, source_variable_name, use_ssa_name, op_name=""):
8195

8296
# Compute variable name
83-
variable_name = self.variable_names.pop(0) if len(self.variable_names) > 0 else ''
84-
if variable_name == '':
97+
variable_name = (
98+
self.variable_names.pop(0) if len(self.variable_names) > 0 else ""
99+
)
100+
if variable_name == "":
85101
# If `use_ssa_name` is set, use the MLIR SSA value name to generate
86102
# a FileCHeck substation string. As FileCheck requires these
87103
# strings to start with a character, skip MLIR variables starting
88104
# with a digit (e.g. `%0`).
105+
#
106+
# The next fallback option is to use the op name, if the
107+
# corresponding match succeeds.
108+
#
109+
# If neither worked, use a generic name: `VAL_#N`.
89110
if use_ssa_name and source_variable_name[0].isalpha():
90111
variable_name = source_variable_name.upper()
112+
elif op_name != "":
113+
variable_name = (
114+
op_name.upper() + "_" + str(self.op_name_counter[op_name])
115+
)
116+
self.op_name_counter[op_name] += 1
91117
else:
92118
variable_name = "VAL_" + str(self.name_counter)
93119
self.name_counter += 1
@@ -123,6 +149,7 @@ def num_scopes(self):
123149
def clear_names(self):
124150
self.name_counter = 0
125151
self.used_variable_names = set()
152+
self.op_name_counter.clear()
126153

127154
class AttributeNamer:
128155

@@ -170,8 +197,12 @@ def process_line(line_chunks, variable_namer, use_ssa_name=False, strict_name_re
170197

171198
# Process the rest that contained an SSA value name.
172199
for chunk in line_chunks:
173-
m = SSA_RE.match(chunk)
174-
ssa_name = m.group(0) if m is not None else ''
200+
ssa = SSA_RE.match(chunk)
201+
op_name_with_dialect = SSA_OP_NAME_RE.search(chunk)
202+
ssa_name = ssa.group(0) if ssa is not None else ""
203+
op_name = (
204+
op_name_with_dialect.group(1) if op_name_with_dialect is not None else ""
205+
)
175206

176207
# Check if an existing variable exists for this name.
177208
variable = None
@@ -185,7 +216,7 @@ def process_line(line_chunks, variable_namer, use_ssa_name=False, strict_name_re
185216
output_line += "%[[" + variable + "]]"
186217
else:
187218
# Otherwise, generate a new variable.
188-
variable = variable_namer.generate_name(ssa_name, use_ssa_name)
219+
variable = variable_namer.generate_name(ssa_name, use_ssa_name, op_name)
189220
if strict_name_re:
190221
# Use stricter regexp for the variable name, if requested.
191222
# Greedy matching may cause issues with the generic '.*'

0 commit comments

Comments
 (0)