-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir] Use MLIR op names when generating FileCheck variables in generate-test-checks.py #160820
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,6 +31,7 @@ | |
import os # Used to advertise this file's name ("autogenerated_note"). | ||
import re | ||
import sys | ||
from collections import Counter | ||
|
||
ADVERT_BEGIN = "// NOTE: Assertions have been autogenerated by " | ||
ADVERT_END = """ | ||
|
@@ -45,6 +46,14 @@ | |
SSA_RE_STR = "[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*" | ||
SSA_RE = re.compile(SSA_RE_STR) | ||
|
||
# Regex matching `dialect.op_name`, where `dialect` is an upstream MLIR | ||
# dialect (e.g. `vector.transfer_read`). | ||
DIALECTS = "acc|affine|amdgpu|amx|arith|arm_neon|arm_sve|arm_sme|async|bufferization|cf|complex|dlti|emitc|\ | ||
func|gpu|index|irdl|linalg|llvm|math|memref|ml_program|mpi|nvgpu|nvvm|omp|pdl_interp|pdl|ptr|quant|\ | ||
rocdl|scf|shape|shard|smt|sparse_tensor|tensor|ub|vcix|vector|wasmssa|x86vector|xegpu|xevm|spirv|tosa|\ | ||
transform" | ||
SSA_OP_NAME_RE = re.compile(rf"\b(?:{DIALECTS})[.]([a-z_]+)\b") | ||
|
||
|
||
# Regex matching the left-hand side of an assignment | ||
SSA_RESULTS_STR = r'\s*(%' + SSA_RE_STR + r')(\s*,\s*(%' + SSA_RE_STR + r'))*\s*=' | ||
SSA_RESULTS_RE = re.compile(SSA_RESULTS_STR) | ||
|
@@ -63,7 +72,12 @@ | |
class VariableNamer: | ||
def __init__(self, variable_names): | ||
self.scopes = [] | ||
# Counter for generic FileCHeck names, e.g. VAL_#N | ||
self.name_counter = 0 | ||
# Counters for FileCheck names derived from Op names, e.g. | ||
# TRANSFER_READ_#N (based on `vector.transfer_read`). Note, there's a | ||
# dedicated counter for every Op type present in the input. | ||
self.op_name_counter = Counter() | ||
|
||
# Number of variable names to still generate in parent scope | ||
self.generate_in_parent_scope_left = 0 | ||
|
@@ -77,17 +91,29 @@ def generate_in_parent_scope(self, n): | |
self.generate_in_parent_scope_left = n | ||
|
||
# Generate a substitution name for the given ssa value name. | ||
def generate_name(self, source_variable_name, use_ssa_name): | ||
def generate_name(self, source_variable_name, use_ssa_name, op_name=""): | ||
|
||
# Compute variable name | ||
variable_name = self.variable_names.pop(0) if len(self.variable_names) > 0 else '' | ||
if variable_name == '': | ||
variable_name = ( | ||
self.variable_names.pop(0) if len(self.variable_names) > 0 else "" | ||
) | ||
if variable_name == "": | ||
# If `use_ssa_name` is set, use the MLIR SSA value name to generate | ||
# a FileCHeck substation string. As FileCheck requires these | ||
# strings to start with a character, skip MLIR variables starting | ||
# with a digit (e.g. `%0`). | ||
# | ||
# The next fallback option is to use the op name, if the | ||
# corresponding match succeeds. | ||
# | ||
# If neither worked, use a generic name: `VAL_#N`. | ||
if use_ssa_name and source_variable_name[0].isalpha(): | ||
variable_name = source_variable_name.upper() | ||
elif op_name != "": | ||
variable_name = ( | ||
op_name.upper() + "_" + str(self.op_name_counter[op_name]) | ||
) | ||
self.op_name_counter[op_name] += 1 | ||
else: | ||
variable_name = "VAL_" + str(self.name_counter) | ||
self.name_counter += 1 | ||
|
@@ -123,6 +149,7 @@ def num_scopes(self): | |
def clear_names(self): | ||
self.name_counter = 0 | ||
self.used_variable_names = set() | ||
self.op_name_counter.clear() | ||
|
||
class AttributeNamer: | ||
|
||
|
@@ -170,8 +197,12 @@ def process_line(line_chunks, variable_namer, use_ssa_name=False, strict_name_re | |
|
||
# Process the rest that contained an SSA value name. | ||
for chunk in line_chunks: | ||
m = SSA_RE.match(chunk) | ||
ssa_name = m.group(0) if m is not None else '' | ||
ssa = SSA_RE.match(chunk) | ||
op_name_with_dialect = SSA_OP_NAME_RE.search(chunk) | ||
ssa_name = ssa.group(0) if ssa is not None else "" | ||
op_name = ( | ||
op_name_with_dialect.group(1) if op_name_with_dialect is not None else "" | ||
) | ||
|
||
# Check if an existing variable exists for this name. | ||
variable = None | ||
|
@@ -185,7 +216,7 @@ def process_line(line_chunks, variable_namer, use_ssa_name=False, strict_name_re | |
output_line += "%[[" + variable + "]]" | ||
else: | ||
# Otherwise, generate a new variable. | ||
variable = variable_namer.generate_name(ssa_name, use_ssa_name) | ||
variable = variable_namer.generate_name(ssa_name, use_ssa_name, op_name) | ||
if strict_name_re: | ||
# Use stricter regexp for the variable name, if requested. | ||
# Greedy matching may cause issues with the generic '.*' | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why this list of dialects? The MLIR syntax is quite strict: there is necessarily a
=
after the SSA value(s) and then the op name is always formed bydialect.op_name
. The RE should be able to find this without hardcoding any list.