diff --git a/helion/language/ref_tile.py b/helion/language/ref_tile.py
index 787e9faf7..f31caa7ad 100644
--- a/helion/language/ref_tile.py
+++ b/helion/language/ref_tile.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import itertools
+import traceback
 from typing import TYPE_CHECKING
 from typing import TypeVar
 
@@ -16,6 +18,22 @@
 _T = TypeVar("_T")
 
+# Counter for generating unique block_ids in ref mode
+_ref_mode_block_id_counter = itertools.count()
+
+# Dict mapping tensor id -> block_ids for tracking (cleared at kernel start)
+_tensor_block_ids: dict[int, tuple[int | None, ...]] = {}
+
+# Path patterns indicating library/framework code (not user code)
+_LIBRARY_PATH_PATTERNS = (
+    "/helion/helion/",
+    "/torch/",
+    "/unittest/",
+    "/pytest/",
+    "/site-packages/",
+)
+
+
@@ ... @@
-    def __init__(self, begin: int, end: int, block_size: int) -> None:
+    def __init__(
+        self, begin: int, end: int, block_size: int, block_id: int | None = None
+    ) -> None:
         super().__init__()
 
         from ..runtime.ref_mode import is_in_ref_mode_context
@@ -51,6 +72,9 @@ def __init__(self, begin: int, end: int, block_size: int) -> None:
         assert is_in_ref_mode_context()
         self._slice = slice(begin, end, None)
         self._block_size = block_size
+        self._block_id = block_id if block_id is not None else next(
+            _ref_mode_block_id_counter
+        )
 
     @classmethod
     def __torch_function__(
@@ -150,13 +174,30 @@ def _handle_getitem(
         args: tuple[object, ...],
         kwargs: dict[str, object] | None,
     ) -> object:
-        """Handle tensor[index] operations."""
+        """Handle tensor[index] operations with tile indices."""
         tensor, index = args
         assert isinstance(tensor, torch.Tensor)
 
+        # Extract block_ids from RefTile indices
+        indices = index if isinstance(index, tuple) else (index,)
+        block_ids: list[int | None] = []
+        for idx in indices:
+            if isinstance(idx, RefTile):
+                block_ids.append(idx._block_id)
+            elif not isinstance(idx, int):  # slice or other -> contributes a dim
+                block_ids.append(None)
+            # int indices reduce dims, so don't append
+
         slice_index = convert_tile_indices_to_slices(index)
         # pyrefly: ignore [bad-index]
-        return tensor[slice_index]
+        result = tensor[slice_index]
+
+        # Register result with block_ids for tracking
+        if block_ids and isinstance(result, torch.Tensor) and result.ndim > 0:
+            if len(block_ids) == result.ndim:
+                _tensor_block_ids[id(result)] = tuple(block_ids)
+
+        return result
 
     @classmethod
     def _handle_setitem(
@@ -174,7 +215,6 @@ def _handle_setitem(
 
         # pyrefly: ignore [bad-index]
         target_shape = tensor[slice_index].shape
-        # Slice value tensor to match target shape if needed
         if (
             isinstance(value, torch.Tensor)
             and value.shape != target_shape
@@ -199,6 +239,76 @@ def index(self) -> torch.Tensor:  # pyrefly: ignore [bad-override]
         from .._compiler.compile_environment import CompileEnvironment
 
         env = CompileEnvironment.current()
-        return torch.arange(
+        data = torch.arange(
             self._slice.start, self._slice.stop, dtype=torch.int32, device=env.device
         )
+        _tensor_block_ids[id(data)] = (self._block_id,)
+        return data
+
+
+def reset_ref_mode_block_id_counter() -> None:
+    """Reset the block_id counter and tracking dict.
+
+    Called at the start of each ref mode kernel execution.
+    """
+    global _ref_mode_block_id_counter
+    _ref_mode_block_id_counter = itertools.count()
+    _tensor_block_ids.clear()
+
+
+def get_block_ids(tensor: torch.Tensor) -> tuple[int | None, ...] | None:
+    """Get block_ids for a tensor if tracked."""
+    return _tensor_block_ids.get(id(tensor))
+
+
+def maybe_set_block_ids(
+    tensor: object, block_ids: tuple[int | None, ...] | None
+) -> None:
+    """Set block_ids for a tensor if block_ids is non-empty and matches tensor ndim."""
+    if (
+        block_ids
+        and isinstance(tensor, torch.Tensor)
+        and len(block_ids) == tensor.ndim
+    ):
+        _tensor_block_ids[id(tensor)] = block_ids
+
+
+def check_broadcast_and_get_result_block_ids(
+    tensors: list[torch.Tensor],
+) -> tuple[int | None, ...] | None:
+    """Check broadcast compatibility and return result block_ids."""
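+    # Worked example (hypothetical shapes): broadcasting a tensor tracked as
+    # [1, u0] against one tracked as [u0, u1] right-aligns both shapes; the
+    # last dim then carries both u0 and u1 on non-size-1 operands, which is
+    # the mismatch rejected below.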
+    # Get tracked tensors (those with block_ids)
+    tracked: list[tuple[torch.Tensor, tuple[int | None, ...]]] = []
+    for t in tensors:
+        bids = _tensor_block_ids.get(id(t))
+        if bids is not None:
+            tracked.append((t, bids))
+
+    if not tracked:
+        return None
+
+    shapes = [[*t.shape] for t, _ in tracked]
+    bids = [[*b] for _, b in tracked]
+    max_rank = max(len(s) for s in shapes)
+
+    # Right-align with padding
+    for i in range(len(shapes)):
+        pad = max_rank - len(shapes[i])
+        shapes[i] = [1] * pad + shapes[i]
+        bids[i] = [None] * pad + bids[i]
+
+    result: list[int | None] = []
+    for d in range(max_rank):
+        ids_in_dim = {
+            bids[i][d]
+            for i in range(len(tracked))
+            if shapes[i][d] != 1 and bids[i][d] is not None
+        }
+        if len(ids_in_dim) >= 2:
+            _raise_mismatch(d, shapes, bids, ids_in_dim)
+        result.append(next(iter(ids_in_dim)) if ids_in_dim else None)
+    return tuple(result)
+
+
+def _raise_mismatch(
+    dim: int,
+    shapes: list[list[int]],
+    bids: list[list[int | None]],
+    ids_in_dim: set[int],
+) -> None:
+    """Raise ShapeMismatch with location info."""
+
+    def fmt(s: list[int], b: list[int | None]) -> str:
+        return (
+            "["
+            + ", ".join(
+                f"u{x}" if x is not None else str(y)
+                for y, x in zip(s, b, strict=False)
+            )
+            + "]"
+        )
+
+    descs = [
+        f"tensor with shape {fmt(s, b)}"
+        for s, b in zip(shapes, bids, strict=False)
+        if s[dim] != 1 and b[dim] in ids_in_dim
+    ][:2]
+
+    loc = ""
+    for f in reversed(traceback.extract_stack()):
+        if not any(p in f.filename for p in _LIBRARY_PATH_PATTERNS):
+            loc = f"\n at {f.filename}:{f.lineno}: {f.line}"
+            break
+
+    raise exc.ShapeMismatch(
+        descs[0] if descs else "unknown",
+        (descs[1] if len(descs) > 1 else "unknown") + loc,
+    )
diff --git a/helion/runtime/ref_mode.py b/helion/runtime/ref_mode.py
index 2e8a7e4d3..f2ea71de4 100644
--- a/helion/runtime/ref_mode.py
+++ b/helion/runtime/ref_mode.py
@@ -18,6 +18,10 @@
 from .._compiler.compile_environment import tls as ce_tls
 from .._utils import convert_size_arg
 from .._utils import create_shape_matching_slices
+from ..language.ref_tile import check_broadcast_and_get_result_block_ids
+from ..language.ref_tile import get_block_ids
+from ..language.ref_tile import maybe_set_block_ids
+from ..language.ref_tile import reset_ref_mode_block_id_counter
 
 if TYPE_CHECKING:
     from typing_extensions import Self
@@ -73,6 +77,7 @@ def __enter__(self) -> Self:
         assert getattr(ref_mode_tls, "context", None) is None, (
             "RefModeContext already active"
         )
+        reset_ref_mode_block_id_counter()
         ce_tls.env = self.env
         ref_mode_tls.context = self
         self.func_mode.__enter__()
@@ -190,7 +195,8 @@ def __torch_function__(
         if func in self._binary_ops:
             return self._handle_binary_op(func, args, kwargs)
 
-        return super().__torch_function__(func, types, args, kwargs)
+        # For all other ops, run and propagate block_ids
+        return self._run_with_block_id_tracking(func, types, args, kwargs)
 
     def _handle_mm_with_bias(
         self,
@@ -295,10 +301,19 @@ def _handle_binary_op(
 
         # Skip if either operand is not a tensor (e.g., scalar operations)
        if not (isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor)):
-            return cast("Callable[..., torch.Tensor]", func)(*args, **kwargs)
+            result = cast("Callable[..., torch.Tensor]", func)(*args, **kwargs)
+            # Propagate block_ids for tensor + scalar
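+            # (e.g. multiplying a tracked tile slice by a Python float keeps
+            # the tensor operand's block_ids; only lhs ids propagate here)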
+            if isinstance(lhs, torch.Tensor):
+                maybe_set_block_ids(result, get_block_ids(lhs))
+            return result
+
+        # Check broadcast compatibility (may raise ShapeMismatch)
+        result_bids = check_broadcast_and_get_result_block_ids([lhs, rhs])
 
         if not self._should_handle_binary_op(lhs, rhs):
-            return cast("Callable[..., torch.Tensor]", func)(*args, **kwargs)
+            result = cast("Callable[..., torch.Tensor]", func)(*args, **kwargs)
+            maybe_set_block_ids(result, result_bids)
+            return result
 
         # Check if this is an in-place operation
         func_name = getattr(func, "__name__", "")
@@ -315,9 +330,10 @@ def _handle_binary_op(
                 lhs[slices], rhs[slices], *args[2:], **kwargs
             )
 
-        # For in-place ops, the operation already modified lhs, so just return it
-        # For out-of-place ops, return the computed result
-        return lhs if is_inplace else result
+        # For in-place ops, return lhs; for out-of-place ops, return result
+        final_result = lhs if is_inplace else result
+        maybe_set_block_ids(final_result, result_bids)
+        return final_result
 
     def _should_handle_binary_op(self, lhs: object, rhs: object) -> bool:
         """Check if binary operation needs special handling.
@@ -349,9 +365,13 @@ def _handle_getitem(
         args: tuple[object, ...],
         kwargs: dict[str, object],
     ) -> torch.Tensor:
-        """Handle tensor indexing with out-of-bounds index clamping."""
+        """Handle tensor indexing with out-of-bounds clamping and block_id tracking."""
         tensor = cast("torch.Tensor", args[0])
         indices: Any = args[1]
+
+        # First check if the tensor has block_ids that need to be propagated
+        tensor_bids = get_block_ids(tensor)
+
         is_tuple = isinstance(indices, tuple)
         indices_list = list(indices) if is_tuple else [indices]
 
@@ -359,7 +379,33 @@ def _handle_getitem(
             if self._is_int_tensor(idx):
                 indices_list[dim] = torch.clamp(idx, min=0, max=tensor.size(dim) - 1)
 
-        return tensor[tuple(indices_list) if is_tuple else indices_list[0]]
+        result = tensor[tuple(indices_list) if is_tuple else indices_list[0]]
+
+        # Propagate block_ids through indexing
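+        # (a None index contributes an untracked dim, an int index drops that
+        # dim's block_id, and slices keep ids aligned dim-by-dim)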
+        if tensor_bids is not None:
+            bids = list(tensor_bids)
+            if not is_tuple:
+                if indices is None:
+                    new_bids = [None, *bids]
+                elif isinstance(indices, int):
+                    new_bids = bids[1:]
+                else:
+                    new_bids = bids
+            else:
+                new_bids = []
+                dim = 0
+                for idx in indices:
+                    if idx is None:
+                        new_bids.append(None)
+                    elif isinstance(idx, int):
+                        dim += 1
+                    else:
+                        if dim < len(bids):
+                            new_bids.append(bids[dim])
+                        dim += 1
+            maybe_set_block_ids(result, tuple(new_bids))
+
+        return result
 
     def _handle_setitem(
         self,
@@ -393,6 +439,46 @@ def _handle_setitem(
 
         tensor[final_indices] = value
 
+    def _run_with_block_id_tracking(
+        self,
+        func: Callable[..., object],
+        types: list[type[object]],
+        args: tuple[object, ...],
+        kwargs: dict[str, object],
+    ) -> object:
+        """Run an operation and propagate block_ids through the result."""
+        # Collect all input tensors
+        input_tensors = [
+            x for x in (*args, *kwargs.values()) if isinstance(x, torch.Tensor)
+        ]
+
+        # Check for reductions
+        func_name = getattr(func, "__name__", "")
+        if func_name in ("sum", "mean", "prod", "max", "min", "std", "var", "any", "all"):
+            if args and isinstance(args[0], torch.Tensor):
+                tensor = args[0]
+                tensor_bids = get_block_ids(tensor)
+                if tensor_bids is not None:
+                    dim = args[1] if len(args) > 1 else kwargs.get("dim")
+                    result = super().__torch_function__(func, types, args, kwargs)
+                    if dim is not None:
+                        bids = list(tensor_bids)
+                        dims = (
+                            {dim}
+                            if isinstance(dim, int)
+                            else set(dim)
+                            if isinstance(dim, (list, tuple))
+                            else set()
+                        )
+                        dims = {d if d >= 0 else len(bids) + d for d in dims}
+                        keepdim = kwargs.get("keepdim", False)
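+                        # e.g. block_ids (0, 1) reduced over dim=1 become (0,);
+                        # with keepdim=True they become (0, None)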
+                        if keepdim:
+                            new_bids = [
+                                None if i in dims else b for i, b in enumerate(bids)
+                            ]
+                        else:
+                            new_bids = [
+                                b for i, b in enumerate(bids) if i not in dims
+                            ]
+                        maybe_set_block_ids(result, tuple(new_bids))
+                    return result
+
+        # Check broadcast compatibility (may raise ShapeMismatch)
+        result_bids = check_broadcast_and_get_result_block_ids(input_tensors)
+
+        # Run the operation
+        result = super().__torch_function__(func, types, args, kwargs)
+        maybe_set_block_ids(result, result_bids)
+        return result
+
     def _setup_binary_ops_handling(self) -> None:
         """Initialize binary operation tracking sets and mappings."""
         # Define binary operations and their variants
diff --git a/test/test_ref_eager.py b/test/test_ref_eager.py
index dc65e98aa..39f84e75d 100644
--- a/test/test_ref_eager.py
+++ b/test/test_ref_eager.py
@@ -8,6 +8,7 @@
 import torch
 
 import helion
+from helion import exc
 from helion._testing import DEVICE
 from helion._testing import TestCase
 from helion._testing import assert_ref_eager_mode
@@ -124,6 +125,110 @@ def kernel(x: torch.Tensor) -> torch.Tensor:
         expected = torch.arange(8, device=DEVICE, dtype=torch.float32)
         torch.testing.assert_close(result, expected)
 
+    def test_tile_broadcast_shape_mismatch(self):
+        """Test that implicit broadcasting between different tile dims raises ShapeMismatch."""
+
+        def invalid_broadcast_kernel_impl(x: torch.Tensor) -> torch.Tensor:
+            M, N = x.shape
+            out = torch.zeros_like(x)
+            for row_tile, col_tile in hl.tile([M, N]):
+                row_idx = row_tile.index  # shape [tile0]
+                col_idx = col_tile.index  # shape [tile1]
+                # Create 2D grid of differences
+                diff = row_idx[:, None] - col_idx  # shape [tile0, tile1]
+                # BUG: row_idx has shape [tile0], diff has shape [tile0, tile1].
+                # PyTorch right-aligns: [tile0] -> [1, tile0], then broadcasts
+                # [1, tile0] with [tile0, tile1], which incorrectly aligns the
+                # tile0 dimension with the tile1 dimension.
+                result = (row_idx > 0) & (diff >= 0)
+                out[row_tile, col_tile] = result.float()
+            return out
+
+        # Test ref mode
+        ref_kernel = helion.kernel(ref_mode=helion.RefMode.EAGER)(
+            invalid_broadcast_kernel_impl
+        )
+        with assert_ref_eager_mode():
+            x = torch.randn(64, 64, device=DEVICE)
+            with self.assertRaises(exc.ShapeMismatch) as cm:
+                ref_kernel(x)
+
+        # Verify error message format matches codegen style with symbolic names
+        error_msg = str(cm.exception)
+        self.assertIn("Shape mismatch between", error_msg)
+        # Should contain symbolic shape names like [1, u0] and [u0, u1]
+        self.assertIn("[1, u0]", error_msg)
+        self.assertIn("[u0, u1]", error_msg)
+        # Should contain stack trace pointing to the exact line
+        self.assertRegex(
+            error_msg,
+            r"test_ref_eager\.py:\d+: result = \(row_idx > 0\) & \(diff >= 0\)",
+        )
+
+        # Verify compile mode also raises the same error (wrapped in InvalidConfig)
+        compile_kernel = helion.kernel(invalid_broadcast_kernel_impl)
+        x = torch.randn(64, 64, device=DEVICE)
+        with self.assertRaises((exc.ShapeMismatch, exc.InvalidConfig)) as cm:
+            compile_kernel(x)
+        # Check that ShapeMismatch is the root cause with the expected message
+        error = cm.exception
+        if isinstance(error, exc.InvalidConfig):
+            self.assertIsInstance(error.__cause__, exc.ShapeMismatch)
+            error = error.__cause__
+        error_msg = str(error)
+        self.assertIn("[1, u0]", error_msg)
+        self.assertIn("[u0, u1]", error_msg)
+
+    def test_reduction_broadcast_shape_mismatch(self):
+        """Test that a reduction followed by a broadcast across different tile dims fails."""
+
+        def invalid_reduction_kernel_impl(x: torch.Tensor) -> torch.Tensor:
+            M, N = x.shape
+            out = torch.zeros_like(x)
+            for tile_m, tile_n in hl.tile([M, N]):
+                data = x[tile_m, tile_n]  # [tile0, tile1]
+                row_sum = data.sum(dim=1)  # [tile0]
+                col_sum = data.sum(dim=0)  # [tile1]
+                # [tile0] + [tile1] - improper broadcast between different tiles
+                combined = row_sum + col_sum
+                out[tile_m, tile_n] = combined.unsqueeze(-1).expand(
+                    -1, data.shape[1]
+                )
+            return out
+
+        # Test ref mode
+        ref_kernel = helion.kernel(ref_mode=helion.RefMode.EAGER)(
+            invalid_reduction_kernel_impl
+        )
+        with assert_ref_eager_mode():
+            x = torch.randn(32, 32, device=DEVICE)
+            with self.assertRaises(exc.ShapeMismatch) as cm:
+                ref_kernel(x)
+
+        # Verify error message format
+        error_msg = str(cm.exception)
+        self.assertIn("Shape mismatch between", error_msg)
+        self.assertIn("[u0]", error_msg)
+        self.assertIn("[u1]", error_msg)
+        # Should contain stack trace pointing to the exact line
+        self.assertRegex(
+            error_msg, r"test_ref_eager\.py:\d+: combined = row_sum \+ col_sum"
+        )
+
+        # Verify compile mode also raises the same error (wrapped in InvalidConfig)
+        compile_kernel = helion.kernel(invalid_reduction_kernel_impl)
+        x = torch.randn(32, 32, device=DEVICE)
+        with self.assertRaises((exc.ShapeMismatch, exc.InvalidConfig)) as cm:
+            compile_kernel(x)
+        # Check that ShapeMismatch is the root cause with the expected message
+        error = cm.exception
+        if isinstance(error, exc.InvalidConfig):
+            self.assertIsInstance(error.__cause__, exc.ShapeMismatch)
+            error = error.__cause__
+        error_msg = str(error)
+        self.assertIn("[u0]", error_msg)
+        self.assertIn("[u1]", error_msg)
+
 
 if __name__ == "__main__":
     unittest.main()
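
A minimal end-to-end sketch of the behavior this patch adds (hypothetical snippet; it assumes `helion.language` is imported as `hl`, as in the test file above, and that a CUDA device is available): in ref eager mode, broadcasting values that carry two different tile dimensions now raises `exc.ShapeMismatch` with symbolic shape names and a pointer to the offending user line, instead of silently mis-aligning dimensions.

    import torch
    import helion
    import helion.language as hl
    from helion import exc

    @helion.kernel(ref_mode=helion.RefMode.EAGER)
    def bad_kernel(x: torch.Tensor) -> torch.Tensor:
        M, N = x.shape
        out = torch.zeros_like(x)
        for tile_m, tile_n in hl.tile([M, N]):
            data = x[tile_m, tile_n]   # tracked as [u0, u1]
            row_sum = data.sum(dim=1)  # tracked as [u0]
            col_sum = data.sum(dim=0)  # tracked as [u1]
            combined = row_sum + col_sum  # raises exc.ShapeMismatch here
            out[tile_m, tile_n] = combined.unsqueeze(-1).expand(-1, data.shape[1])
        return out

    try:
        bad_kernel(torch.randn(32, 32, device="cuda"))
    except exc.ShapeMismatch as e:
        print(e)  # message names [u0] and [u1] and points at the line above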