Skip to content

Commit 3a0e975

Browse files
authored
Add provenance annotations to output code (#988)
1 parent 69b3fb6 commit 3a0e975

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+9282
-534
lines changed

docs/api/settings.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,11 @@ See :class:`helion.autotuner.LocalAutotuneCache` for details on cache keys and b
190190
191191
Print generated Triton code to stderr. Default is ``False``. Controlled by ``HELION_PRINT_OUTPUT_CODE=1``.
192192
193+
.. autoattribute:: Settings.output_origin_lines
194+
195+
Annotate generated Triton code with ``# src[<file>:<line>]`` comments indicating the originating Helion statements.
196+
Default is ``True``. Controlled by ``HELION_OUTPUT_ORIGIN_LINES`` (set to ``0`` to disable).
197+
193198
.. autoattribute:: Settings.ignore_warnings
194199
195200
List of warning types to suppress during compilation. Default is an empty list.
@@ -253,6 +258,7 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"PatternSearch"``, ``"Differe
253258
| ``HELION_CACHE_DIR`` | ``LocalAutotuneCache`` | Override the on-disk directory used for cached autotuning artifacts. |
254259
| ``HELION_SKIP_CACHE`` | ``LocalAutotuneCache`` | When set to ``1``, ignore cached autotuning entries and rerun searches. |
255260
| ``HELION_PRINT_OUTPUT_CODE`` | ``print_output_code`` | Print generated Triton code to stderr for inspection. |
261+
| ``HELION_OUTPUT_ORIGIN_LINES`` | ``output_origin_lines`` | Include ``# src[...]`` comments in generated Triton code; set to ``0`` to disable. |
256262
| ``HELION_IGNORE_WARNINGS`` | ``ignore_warnings`` | Comma-separated warning names defined in ``helion.exc`` to suppress. |
257263
| ``HELION_ALLOW_WARP_SPECIALIZE`` | ``allow_warp_specialize`` | Permit warp-specialized code generation for ``tl.range``. |
258264
| ``HELION_DEBUG_DTYPE_ASSERTS`` | ``debug_dtype_asserts`` | Inject dtype assertions after each lowering step. |

helion/_compiler/ast_extension.py

Lines changed: 150 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,19 @@
22

33
import ast
44
import enum
5+
import linecache
6+
import os
57
import re
8+
import textwrap
69
import threading
710
import typing
811
from typing import TYPE_CHECKING
912
from typing import TypeVar
1013

1114
from .. import exc
15+
from .output_lines import OutputLines
1216
from .source_location import SourceLocation
17+
from .source_location import UnknownLocation
1318
from .source_location import current_location
1419

1520
if TYPE_CHECKING:
@@ -302,6 +307,148 @@ def visit_Tuple(self, node: ast.Tuple) -> None:
302307
super().visit_Tuple(node)
303308

304309

305-
def unparse(ast_obj: ast.AST) -> str:
306-
unparser = _TupleParensRemovedUnparser()
307-
return unparser.visit(ast_obj)
310+
class _LocationAnnotatingOutputLines(OutputLines):
    """Output buffer that can emit ``# src[...]`` provenance comments.

    Each distinct ``(filename, first line, last line)`` location is rendered
    into a tuple of comment lines once and memoized. Asking for the same
    location twice in a row inserts nothing the second time.
    """

    def __init__(self, parent: ast._Unparser) -> None:  # pyright: ignore[reportAttributeAccessIssue]
        super().__init__(parent)
        # Rendered comment lines, keyed by (filename, start line, end line).
        self._cache: dict[tuple[str, int, int], tuple[str, ...]] = {}
        # Key of the location annotated most recently, or None after a reset.
        self._last_location_key: tuple[str, int, int] | None = None

    def reset_last_location(self) -> None:
        """Forget the most recently annotated location."""
        super().reset_last_location()
        self._last_location_key = None

    def insert_location_comment(self, location: object) -> None:
        """Insert provenance comments for ``location`` unless it repeats the last one.

        Anything that is not a ``SourceLocation``/``UnknownLocation`` is
        treated as an unknown location.
        """
        loc = (
            location
            if isinstance(location, (SourceLocation, UnknownLocation))
            else UnknownLocation()
        )
        key = self._location_key(loc)
        if key is None or key == self._last_location_key:
            return

        comments = self._comments_for_key(key, loc)
        if not comments:
            return
        self.insert_comments(comments)
        self._last_location_key = key

    def _location_key(
        self, location: SourceLocation | UnknownLocation
    ) -> tuple[str, int, int] | None:
        """Return the cache key for ``location``.

        A falsy location maps to a fixed ``("<unknown>", 0, 0)`` key; a
        location without a filename yields ``None``.
        """
        if not location:
            return ("<unknown>", 0, 0)
        path = location.filename
        if not path:
            return None
        first = location.lineno or 0
        # A missing end line collapses the range to the single start line.
        last = location.end_lineno or first
        return (path, first, last)

    def _comments_for_key(
        self,
        key: tuple[str, int, int],
        location: SourceLocation | UnknownLocation,
    ) -> tuple[str, ...]:
        """Return the comment lines for ``key``, rendering them on first use."""
        hit = self._cache.get(key)
        if hit is not None:
            return hit

        rendered = self._render_comments(key, location)
        self._cache[key] = rendered
        return rendered

    def _render_comments(
        self,
        key: tuple[str, int, int],
        location: SourceLocation | UnknownLocation,
    ) -> tuple[str, ...]:
        """Build the ``# src[...]`` comment lines for one location (uncached)."""
        filename, start, end = key
        if not location:
            return ("# src[unknown]: [source unavailable]",)

        base_name = os.path.basename(filename)
        # Shared fallback for every case where no source text can be shown.
        unavailable = (f"# src[{base_name}:{start}]: [source unavailable]",)
        if start <= 0:
            return unavailable

        source_lines = linecache.getlines(filename)
        if not source_lines:
            # Retry once after asking linecache to revalidate its entry.
            linecache.checkcache(filename)
            source_lines = linecache.getlines(filename)
        if not source_lines:
            return unavailable

        selected = source_lines[start - 1 : end]
        if not selected:
            return unavailable

        max_lines = 3
        shown = textwrap.dedent("".join(selected[:max_lines]))
        # Blank lines are skipped but still consume a line number.
        rendered = [
            f"# src[{base_name}:{start + offset}]: {text.rstrip()}"
            for offset, text in enumerate(shown.splitlines())
            if text.strip()
        ]
        if len(selected) > max_lines:
            # Mark that the statement continues past the lines shown.
            range_part = f"{start}-{end}" if end != start else f"{start}"
            rendered.append(f"# src[{base_name}:{range_part}]: ...")
        return tuple(rendered) if rendered else unavailable
400+
401+
402+
class _HelionUnparser(_TupleParensRemovedUnparser):
    """Unparser that writes into an ``OutputLines`` buffer.

    When ``output_origin_lines`` is true the buffer is a
    ``_LocationAnnotatingOutputLines`` and each eligible statement is offered
    to it for a source-location comment before being unparsed.
    """

    _indent: int

    def __init__(
        self, *args: object, output_origin_lines: bool = True, **kwargs: object
    ) -> None:
        super().__init__(*args, **kwargs)
        buffer_cls = (
            _LocationAnnotatingOutputLines if output_origin_lines else OutputLines
        )
        self.output = buffer_cls(self)
        # The same buffer is also installed as `_source`.
        self._source = self.output
        self._output_origin_lines = output_origin_lines

    def visit(self, node: ast.AST) -> str:  # type: ignore[override]
        """Reset the buffer, unparse ``node`` and return the joined text."""
        buffer = self.output
        buffer.lines.clear()
        buffer.last_newline = 0
        buffer.reset_last_location()
        self.traverse(node)
        return "".join(buffer)

    def maybe_newline(self) -> None:  # type: ignore[override]
        """Skip exactly one newline request after comments were inserted."""
        output = getattr(self, "output", None)
        if output is None or not getattr(output, "_skip_next_newline", False):
            super().maybe_newline()
            return
        # Consume the one-shot flag without emitting a newline.
        output._skip_next_newline = False

    def traverse(self, node: ast.AST | list[ast.AST]) -> None:  # pyright: ignore[reportSignatureIssue]
        """Offer eligible statements for annotation, then defer to the base class.

        Function, class and import statements are never annotated.
        """
        wants_comment = (
            self._output_origin_lines
            and isinstance(node, ExtendedAST)
            and isinstance(node, ast.stmt)
            and not isinstance(
                node,
                (
                    ast.FunctionDef,
                    ast.AsyncFunctionDef,
                    ast.ClassDef,
                    ast.Import,
                    ast.ImportFrom,
                ),
            )
        )
        if wants_comment:
            self.output.insert_location_comment(node._location)
        super().traverse(node)
448+
449+
450+
def unparse(ast_obj: ast.AST, *, output_origin_lines: bool = True) -> str:
    """Render ``ast_obj`` to source text.

    Args:
        ast_obj: The AST node to unparse.
        output_origin_lines: Forwarded to ``_HelionUnparser``; selects whether
            the location-annotating output buffer is used.

    Returns:
        The string produced by ``_HelionUnparser.visit``.
    """
    unparser = _HelionUnparser(output_origin_lines=output_origin_lines)
    try:
        return unparser.visit(ast_obj)
    finally:
        # The buffer keeps a `parent` reference to the unparser, and the
        # unparser holds the buffer under both `output` and `_source`.
        # Dropping only one of them leaves the cycle intact, so drop both,
        # including when `visit` raises.
        del unparser.output
        del unparser._source

helion/_compiler/output_lines.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from __future__ import annotations
2+
3+
from typing import Iterable
4+
from typing import Iterator
5+
from typing import Sequence
6+
7+
8+
class OutputLines:
    """
    Helper to build source text while keeping track of the most recent newline so
    callers can inject annotations before the currently buffered statement.

    ``lines`` holds the text chunks appended so far and ``last_newline`` is the
    index in ``lines`` just after the last chunk that ended with a newline.
    ``parent`` is only consulted for an optional ``_indent`` attribute that
    sets how deeply inserted comments are indented.
    """

    def __init__(self, parent: object) -> None:
        super().__init__()
        self.lines: list[str] = []
        self.last_newline = 0
        self.parent = parent
        # One-shot flag set by insert_comments(); cleared by reset_last_location().
        self._skip_next_newline = False

    def extend(self, chunks: Iterable[str]) -> None:
        """Append text while tracking the index after the last newline."""
        concatenated = "".join(chunks)
        if not concatenated:
            return

        new_lines = concatenated.splitlines(keepends=True)
        self.lines.extend(new_lines)

        if new_lines[-1].endswith("\n"):
            self.last_newline = len(self.lines)
        elif len(new_lines) > 1:
            # Second to last line must end in newline if the last one did not.
            assert new_lines[-2].endswith("\n")
            self.last_newline = len(self.lines) - 1
        # A single chunk without a newline leaves last_newline unchanged.

    def append(self, text: str) -> None:
        """Append a single piece of text; equivalent to ``extend([text])``."""
        self.extend([text])

    def insert_comments(self, comments: Sequence[str]) -> None:
        """Insert comment lines right before the current statement.

        Each comment must be a single line. All comments are validated before
        the buffer is touched, so a rejected call leaves it unchanged.
        """
        if not comments:
            return
        # Validate first: failing midway would leave a half-annotated buffer.
        assert all("\n" not in comment for comment in comments)

        # Four spaces per level, matching the indentation ast's unparser writes.
        indent = "    " * getattr(self.parent, "_indent", 0)
        rendered = [f"{indent}{comment}\n" for comment in comments]

        if self.lines and not self.lines[-1].endswith("\n"):
            # Terminate the pending partial line so the comments start on
            # their own line after it.
            self.lines[-1] = f"{self.lines[-1]}\n"
            self.last_newline = len(self.lines)

        insert_at = min(max(self.last_newline, 0), len(self.lines))
        self.lines[insert_at:insert_at] = rendered
        self.last_newline = insert_at + len(rendered)
        # The comments already end in a newline; flag it so the consumer of
        # this attribute can suppress the next one.
        self._skip_next_newline = True

    def insert_annotation(self, annotation: str) -> None:
        """Insert ``annotation`` as a single ``# ...`` comment line."""
        self.insert_comments((f"# {annotation}",))

    def reset_last_location(self) -> None:
        """Clear the pending skip-newline flag."""
        self._skip_next_newline = False

    def insert_location_comment(self, location: object) -> None:
        """Do nothing; subclasses that track source locations override this."""
        # Base OutputLines does not track source locations; override when needed.
        return None

    def __bool__(self) -> bool:
        return bool(self.lines)

    def __len__(self) -> int:
        return len(self.lines)

    def __iter__(self) -> Iterator[str]:
        return iter(self.lines)

helion/_compiler/type_printer.py

Lines changed: 1 addition & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,48 +4,10 @@
44

55
from .ast_extension import ExtendedAST
66
from .ast_extension import _TupleParensRemovedUnparser
7+
from .output_lines import OutputLines
78

89
if TYPE_CHECKING:
910
import ast
10-
from collections.abc import Iterator
11-
12-
13-
class OutputLines:
14-
def __init__(self, parent: ASTPrinter) -> None:
15-
super().__init__()
16-
self.lines: list[str] = []
17-
self.last_newline = 0
18-
self.parent = parent
19-
20-
def extend(self, lines: list[str]) -> None:
21-
"""Keep track of the index right after the last newline so insert_annotation can insert in the correct spot."""
22-
lines = "".join(lines).splitlines(keepends=True)
23-
if not lines:
24-
return
25-
self.lines.extend(lines)
26-
if lines[-1].endswith("\n"):
27-
self.last_newline = len(self.lines)
28-
elif len(lines) > 1:
29-
assert lines[-2].endswith("\n")
30-
self.last_newline = len(self.lines) - 1
31-
32-
def __bool__(self) -> bool:
33-
return bool(self.lines)
34-
35-
def __len__(self) -> int:
36-
return len(self.lines)
37-
38-
def __iter__(self) -> Iterator[str]:
39-
return iter(self.lines)
40-
41-
def insert_annotation(self, annotation: str) -> None:
42-
assert "\n" not in annotation
43-
indent = " " * self.parent._indent
44-
self.lines.insert(self.last_newline, f"{indent}# {annotation}\n")
45-
self.last_newline += 1
46-
47-
def append(self, text: str) -> None:
48-
self.extend([text])
4911

5012

5113
class ASTPrinter(_TupleParensRemovedUnparser):

helion/_testing.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -772,8 +772,28 @@ def normalize_codegen_variants(code: str) -> str:
772772
# tl.sqrt( -> tl.sqrt_rn(
773773
return re.sub(r"\btl\.sqrt\s*\(", "tl.sqrt_rn(", code)
774774

775+
@staticmethod
776+
def normalize_source_comment_structure(code: str) -> str:
777+
pattern = re.compile(
778+
r"^(?P<indent>\s*)# src\[(?P<prefix>[^:\]]+:)(?P<start>\d+|N)(?:-(?P<end>\d+|N))?]: (?P<text>.*?)(?P<newline>\r?\n|$)",
779+
flags=re.MULTILINE,
780+
)
781+
782+
def replacer(match: re.Match[str]) -> str:
783+
text = match.group("text").rstrip()
784+
if not text.strip():
785+
return ""
786+
indent = match.group("indent")
787+
prefix = match.group("prefix")
788+
suffix = "N-N" if match.group("end") is not None else "N"
789+
newline = match.group("newline")
790+
return f"{indent}# src[{prefix}{suffix}]: {text}{newline}"
791+
792+
return pattern.sub(replacer, code)
793+
775794
@classmethod
776795
def normalize_code(cls, code: str) -> str:
796+
code = cls.normalize_source_comment_structure(code)
777797
code = cls.normalize_tensor_descriptors(code)
778798
code = cls.normalize_device_name(code)
779799
code = cls.normalize_codegen_variants(code)

helion/runtime/kernel.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,11 @@ def format_kernel_decorator(self, config: Config, settings: Settings) -> str:
394394
return f"@helion.kernel(config={config.__repr__()}, static_shapes={settings.static_shapes})"
395395

396396
def to_triton_code(
397-
self, config: ConfigLike | None = None, emit_repro_caller: bool = False
397+
self,
398+
config: ConfigLike | None = None,
399+
*,
400+
emit_repro_caller: bool = False,
401+
output_origin_lines: bool | None = None,
398402
) -> str:
399403
"""
400404
Generate Triton code for the kernel based on the given configuration.
@@ -413,7 +417,11 @@ def to_triton_code(
413417
config = Config(**config) # pyright: ignore[reportArgumentType]
414418
self.env.config_spec.normalize(config)
415419
root = generate_ast(self.host_function, config, emit_repro_caller)
416-
return get_needed_imports(root) + unparse(root)
420+
if output_origin_lines is None:
421+
output_origin_lines = self.settings.output_origin_lines
422+
return get_needed_imports(root) + unparse(
423+
root, output_origin_lines=output_origin_lines
424+
)
417425

418426
def compile_config(
419427
self, config: ConfigLike | None = None, *, allow_print: bool = True

helion/runtime/settings.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,11 @@ class _Settings:
314314
_env_get_bool, "HELION_PRINT_OUTPUT_CODE", False
315315
)
316316
)
317+
output_origin_lines: bool = dataclasses.field(
318+
default_factory=functools.partial(
319+
_env_get_bool, "HELION_OUTPUT_ORIGIN_LINES", True
320+
)
321+
)
317322
force_autotune: bool = dataclasses.field(
318323
default_factory=functools.partial(_env_get_bool, "HELION_FORCE_AUTOTUNE", False)
319324
)
@@ -379,6 +384,10 @@ class Settings(_Settings):
379384
"Set HELION_AUTOTUNE_IGNORE_ERRORS=1 to enable globally."
380385
),
381386
"print_output_code": "If True, print the output code of the kernel to stderr.",
387+
"output_origin_lines": (
388+
"If True, annotate generated Triton code with source-origin comments. "
389+
"Set HELION_OUTPUT_ORIGIN_LINES=0 to disable."
390+
),
382391
"force_autotune": "If True, force autotuning even if a config is provided.",
383392
"autotune_config_overrides": (
384393
"Dictionary of config key/value pairs forced during autotuning. "

0 commit comments

Comments
 (0)