|
2 | 2 |
|
3 | 3 | import ast |
4 | 4 | import enum |
| 5 | +import linecache |
| 6 | +import os |
5 | 7 | import re |
| 8 | +import textwrap |
6 | 9 | import threading |
7 | 10 | import typing |
8 | 11 | from typing import TYPE_CHECKING |
9 | 12 | from typing import TypeVar |
10 | 13 |
|
11 | 14 | from .. import exc |
| 15 | +from .output_lines import OutputLines |
12 | 16 | from .source_location import SourceLocation |
| 17 | +from .source_location import UnknownLocation |
13 | 18 | from .source_location import current_location |
14 | 19 |
|
15 | 20 | if TYPE_CHECKING: |
@@ -302,6 +307,148 @@ def visit_Tuple(self, node: ast.Tuple) -> None: |
302 | 307 | super().visit_Tuple(node) |
303 | 308 |
|
304 | 309 |
|
305 | | -def unparse(ast_obj: ast.AST) -> str: |
306 | | - unparser = _TupleParensRemovedUnparser() |
307 | | - return unparser.visit(ast_obj) |
class _LocationAnnotatingOutputLines(OutputLines):
    """``OutputLines`` variant that emits ``# src[file:line]: ...`` comments.

    For each distinct source location handed to ``insert_location_comment``
    the matching lines of the original file are looked up through
    ``linecache`` and passed to ``insert_comments`` (inherited; its exact
    placement behavior is defined by ``OutputLines``, not shown here).
    Consecutive requests for the same location produce only one comment
    group, and rendered comments are memoized per location key.
    """

    # Number of source lines quoted per location before the rest is elided.
    _MAX_SNIPPET_LINES = 3

    def __init__(self, parent: ast._Unparser) -> None:  # pyright: ignore[reportAttributeAccessIssue]
        super().__init__(parent)
        # Rendered comment lines, keyed by (filename, first line, last line).
        self._cache: dict[tuple[str, int, int], tuple[str, ...]] = {}
        # Key of the most recently emitted location; used to skip repeats.
        self._last_location_key: tuple[str, int, int] | None = None

    def reset_last_location(self) -> None:
        """Forget the last emitted location so the next one is always written."""
        super().reset_last_location()
        self._last_location_key = None

    def insert_location_comment(self, location: object) -> None:
        """Emit origin comments for ``location`` unless it was just emitted.

        Args:
            location: Expected to be a ``SourceLocation`` or
                ``UnknownLocation``; any other object is treated as
                ``UnknownLocation()``.
        """
        if not isinstance(location, (SourceLocation, UnknownLocation)):
            location = UnknownLocation()
        key = self._location_key(location)
        if key is None or key == self._last_location_key:
            return

        comments = self._comments_for_key(key, location)
        if comments:
            self.insert_comments(comments)
            self._last_location_key = key

    def _location_key(
        self, location: SourceLocation | UnknownLocation
    ) -> tuple[str, int, int] | None:
        """Return the ``(filename, start, end)`` key for ``location``.

        A falsy location maps to the ``("<unknown>", 0, 0)`` sentinel key.
        A truthy location without a filename yields ``None``, which makes the
        caller skip it. A missing ``lineno`` becomes ``0`` and a missing
        ``end_lineno`` falls back to the start line.
        """
        if not location:
            return ("<unknown>", 0, 0)
        filename = location.filename
        if not filename:
            return None
        start = location.lineno or 0
        end = location.end_lineno or start
        return (filename, start, end)

    def _comments_for_key(
        self,
        key: tuple[str, int, int],
        location: SourceLocation | UnknownLocation,
    ) -> tuple[str, ...]:
        """Return the comment lines for ``key``, rendering them at most once."""
        cached = self._cache.get(key)
        if cached is not None:
            return cached

        comments = self._render_comments(key, location)
        self._cache[key] = comments
        return comments

    def _render_comments(
        self,
        key: tuple[str, int, int],
        location: SourceLocation | UnknownLocation,
    ) -> tuple[str, ...]:
        """Build the ``# src[...]`` comment lines for one location.

        Returns a single ``[source unavailable]`` line when the location is
        falsy, the start line is not positive, or ``linecache`` has no text
        for the requested range. Otherwise returns one comment per non-blank
        source line (at most ``_MAX_SNIPPET_LINES`` lines are considered),
        followed by a ``start-end`` ellipsis line when the range was cut.
        """
        filename, start, end = key
        if not location:
            return ("# src[unknown]: [source unavailable]",)

        base_name = os.path.basename(filename)
        unavailable = (f"# src[{base_name}:{start}]: [source unavailable]",)
        if start <= 0:
            return unavailable

        lines = linecache.getlines(filename)
        if not lines:
            # Drop a possibly stale cache entry and try once more.
            linecache.checkcache(filename)
            lines = linecache.getlines(filename)

        # An empty ``lines`` list also yields an empty slice here.
        snippet_full = lines[start - 1 : end]
        if not snippet_full:
            return unavailable

        max_lines = self._MAX_SNIPPET_LINES
        truncated = len(snippet_full) > max_lines
        snippet = snippet_full[:max_lines]

        # Give every linecache entry exactly one "\n" terminator and split on
        # "\n" only. ``str.splitlines`` also breaks on form feed, vertical
        # tab, U+2028 and similar characters, which would desynchronize the
        # reported line numbers from the file's real line numbers.
        normalized = "".join(entry.rstrip("\r\n") + "\n" for entry in snippet)
        dedented = textwrap.dedent(normalized)

        body_list: list[str] = []
        for lineno, dedented_line in enumerate(dedented.split("\n"), start=start):
            stripped = dedented_line.rstrip()
            if not stripped:
                # Blank lines (and the empty piece after the final "\n")
                # produce no comment but still consume a line number.
                continue
            body_list.append(f"# src[{base_name}:{lineno}]: {stripped}")
        if truncated:
            # Truncation implies the range spans more than one line, so the
            # range is always rendered as ``start-end``.
            body_list.append(f"# src[{base_name}:{start}-{end}]: ...")

        return tuple(body_list) if body_list else unavailable
| 401 | + |
class _HelionUnparser(_TupleParensRemovedUnparser):
    """Unparser that writes into an ``OutputLines`` buffer.

    When ``output_origin_lines`` is true the buffer is a
    ``_LocationAnnotatingOutputLines`` and every eligible statement is
    preceded by a comment describing where it came from; otherwise a plain
    ``OutputLines`` buffer is used and no comments are requested.
    """

    _indent: int

    def __init__(
        self, *args: object, output_origin_lines: bool = True, **kwargs: object
    ) -> None:
        super().__init__(*args, **kwargs)
        buffer_cls = (
            _LocationAnnotatingOutputLines if output_origin_lines else OutputLines
        )
        self.output = buffer_cls(self)
        # The base unparser writes through ``_source``; point it at the buffer.
        self._source = self.output
        self._output_origin_lines = output_origin_lines

    def visit(self, node: ast.AST) -> str:  # type: ignore[override]
        """Reset the buffer, unparse ``node`` and return the joined text."""
        buffer = self.output
        buffer.lines.clear()
        buffer.last_newline = 0
        buffer.reset_last_location()
        self.traverse(node)
        return "".join(self.output)

    def maybe_newline(self) -> None:  # type: ignore[override]
        """Emit the usual newline unless the buffer asked to skip one.

        A truthy ``_skip_next_newline`` flag on the buffer is consumed (reset
        to ``False``) instead of delegating to the base implementation.
        """
        buffer = getattr(self, "output", None)
        if buffer is None or not getattr(buffer, "_skip_next_newline", False):
            super().maybe_newline()
            return
        buffer._skip_next_newline = False

    def traverse(self, node: ast.AST | list[ast.AST]) -> None:  # pyright: ignore[reportSignatureIssue]
        """Traverse ``node``, first requesting an origin comment if eligible.

        A comment is requested only when origin output is enabled and the
        node is an ``ExtendedAST`` statement other than a function, async
        function or class definition, or an import.
        """
        wants_comment = (
            self._output_origin_lines
            and isinstance(node, ExtendedAST)
            and isinstance(node, ast.stmt)
            and not isinstance(
                node,
                (
                    ast.FunctionDef,
                    ast.AsyncFunctionDef,
                    ast.ClassDef,
                    ast.Import,
                    ast.ImportFrom,
                ),
            )
        )
        if wants_comment:
            self.output.insert_location_comment(node._location)
        super().traverse(node)
| 448 | + |
| 449 | + |
def unparse(ast_obj: ast.AST, *, output_origin_lines: bool = True) -> str:
    """Convert ``ast_obj`` back to source text.

    Args:
        ast_obj: The AST node to unparse.
        output_origin_lines: When true, eligible statements are preceded by
            ``# src[...]`` comments describing their original location.

    Returns:
        The generated source code as a single string.
    """
    unparser = _HelionUnparser(output_origin_lines=output_origin_lines)
    try:
        return unparser.visit(ast_obj)
    finally:
        # Break the unparser <-> buffer reference cycle (the buffer is built
        # with the unparser as its parent). ``_HelionUnparser.__init__`` binds
        # the same buffer to both ``output`` and ``_source``, so both
        # attributes must be dropped; doing it in ``finally`` also covers a
        # traversal that raises.
        del unparser.output
        del unparser._source
0 commit comments