Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions mlir/utils/generate-test-checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import os # Used to advertise this file's name ("autogenerated_note").
import re
import sys
from collections import Counter

ADVERT_BEGIN = "// NOTE: Assertions have been autogenerated by "
ADVERT_END = """
Expand All @@ -45,6 +46,14 @@
SSA_RE_STR = "[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*"
SSA_RE = re.compile(SSA_RE_STR)

# Regex matching `dialect.op_name`, where `dialect` is an upstream MLIR
# dialect (e.g. `vector.transfer_read`).
DIALECTS = "acc|affine|amdgpu|amx|arith|arm_neon|arm_sve|arm_sme|async|bufferization|cf|complex|dlti|emitc|\
func|gpu|index|irdl|linalg|llvm|math|memref|ml_program|mpi|nvgpu|nvvm|omp|pdl_interp|pdl|ptr|quant|\
rocdl|scf|shape|shard|smt|sparse_tensor|tensor|ub|vcix|vector|wasmssa|x86vector|xegpu|xevm|spirv|tosa|\
transform"
SSA_OP_NAME_RE = re.compile(rf"\b(?:{DIALECTS})[.]([a-z_]+)\b")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this list of dialects? The MLIR syntax is quite strict: there is necessarily a = after the SSA value(s) and then the op name is always formed by dialect.op_name . The RE should be able to find this without hardcoding any list.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens in the case where the dialect is elided? +1 to removing list of dialects.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens in the case where the dialect is elided?

In the input IR? Is it ever?

+1 to removing list of dialects.

Done


# Regex matching the left-hand side of an assignment
SSA_RESULTS_STR = r'\s*(%' + SSA_RE_STR + r')(\s*,\s*(%' + SSA_RE_STR + r'))*\s*='
SSA_RESULTS_RE = re.compile(SSA_RESULTS_STR)
Expand All @@ -63,7 +72,12 @@
class VariableNamer:
def __init__(self, variable_names):
self.scopes = []
# Counter for generic FileCHeck names, e.g. VAL_#N
self.name_counter = 0
# Counters for FileCheck names derived from Op names, e.g.
# TRANSFER_READ_#N (based on `vector.transfer_read`). Note, there's a
# dedicated counter for every Op type present in the input.
self.op_name_counter = Counter()

# Number of variable names to still generate in parent scope
self.generate_in_parent_scope_left = 0
Expand All @@ -77,17 +91,29 @@ def generate_in_parent_scope(self, n):
self.generate_in_parent_scope_left = n

# Generate a substitution name for the given ssa value name.
def generate_name(self, source_variable_name, use_ssa_name):
def generate_name(self, source_variable_name, use_ssa_name, op_name=""):

# Compute variable name
variable_name = self.variable_names.pop(0) if len(self.variable_names) > 0 else ''
if variable_name == '':
variable_name = (
self.variable_names.pop(0) if len(self.variable_names) > 0 else ""
)
if variable_name == "":
# If `use_ssa_name` is set, use the MLIR SSA value name to generate
# a FileCHeck substation string. As FileCheck requires these
# strings to start with a character, skip MLIR variables starting
# with a digit (e.g. `%0`).
#
# The next fallback option is to use the op name, if the
# corresponding match succeeds.
#
# If neither worked, use a generic name: `VAL_#N`.
if use_ssa_name and source_variable_name[0].isalpha():
variable_name = source_variable_name.upper()
elif op_name != "":
variable_name = (
op_name.upper() + "_" + str(self.op_name_counter[op_name])
)
self.op_name_counter[op_name] += 1
else:
variable_name = "VAL_" + str(self.name_counter)
self.name_counter += 1
Expand Down Expand Up @@ -123,6 +149,7 @@ def num_scopes(self):
def clear_names(self):
self.name_counter = 0
self.used_variable_names = set()
self.op_name_counter.clear()

class AttributeNamer:

Expand Down Expand Up @@ -170,8 +197,12 @@ def process_line(line_chunks, variable_namer, use_ssa_name=False, strict_name_re

# Process the rest that contained an SSA value name.
for chunk in line_chunks:
m = SSA_RE.match(chunk)
ssa_name = m.group(0) if m is not None else ''
ssa = SSA_RE.match(chunk)
op_name_with_dialect = SSA_OP_NAME_RE.search(chunk)
ssa_name = ssa.group(0) if ssa is not None else ""
op_name = (
op_name_with_dialect.group(1) if op_name_with_dialect is not None else ""
)

# Check if an existing variable exists for this name.
variable = None
Expand All @@ -185,7 +216,7 @@ def process_line(line_chunks, variable_namer, use_ssa_name=False, strict_name_re
output_line += "%[[" + variable + "]]"
else:
# Otherwise, generate a new variable.
variable = variable_namer.generate_name(ssa_name, use_ssa_name)
variable = variable_namer.generate_name(ssa_name, use_ssa_name, op_name)
if strict_name_re:
# Use stricter regexp for the variable name, if requested.
# Greedy matching may cause issues with the generic '.*'
Expand Down