Open
Labels
interpret mode [HELION_INTERPRET=1 (ref mode) related issues]; symbolic shape
Description
Repro:
import torch
import helion
import helion.language as hl


@helion.kernel
def invalid_broadcast_kernel(x: torch.Tensor) -> torch.Tensor:
    M, N = x.shape
    out = torch.zeros_like(x)
    for row_tile, col_tile in hl.tile([M, N]):
        row_idx = row_tile.index  # shape [tile0]
        col_idx = col_tile.index  # shape [tile1]
        # Create 2D grid of differences
        diff = row_idx[:, None] - col_idx  # shape [tile0, tile1]
        # BUG: row_idx has shape [tile0], diff has shape [tile0, tile1]
        # PyTorch right-aligns: [tile0] -> [1, tile0]
        # Then broadcasts [1, tile0] with [tile0, tile1], which incorrectly
        # aligns tile0 dimension with tile1 dimension
        result = (row_idx > 0) & (diff >= 0)
        out[row_tile, col_tile] = result.float()
    return out


x = torch.randn(64, 64, device="cuda")
invalid_broadcast_kernel(x)

The above script passes with HELION_INTERPRET=1 but fails otherwise (that is, in compile mode). Ideally, interpret mode should have the same indexing/broadcasting behavior as compile mode, and should throw the same error as compile mode for anything unsupported.
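To make the right-alignment described in the repro comments concrete, here is a minimal pure-Python sketch of the standard NumPy/PyTorch broadcasting rule applied to shapes only. `broadcast_shape` is a hypothetical helper written for illustration (it is not part of Helion or PyTorch), the tile sizes are assumed values, and the sketch does not model how either Helion mode is implemented.

```python
def broadcast_shape(a, b):
    """Broadcast two shapes with right-aligned (NumPy/PyTorch-style) rules."""
    ndim = max(len(a), len(b))
    # Left-pad the shorter shape with 1s so both shapes have the same rank.
    a = (1,) * (ndim - len(a)) + tuple(a)
    b = (1,) * (ndim - len(b)) + tuple(b)
    out = []
    for da, db in zip(a, b):
        if da == db or db == 1:
            out.append(da)
        elif da == 1:
            out.append(db)
        else:
            raise ValueError(f"cannot broadcast {a} with {b}")
    return tuple(out)


# Assumed tile sizes, for illustration only.
# Mask of shape [tile0] against diff of shape [tile0, tile1] with equal sizes:
# [64] is padded to [1, 64], so the row axis is silently matched to the column axis.
print(broadcast_shape((64,), (64, 64)))

# With unequal tile sizes the same expression is rejected.
try:
    broadcast_shape((32,), (32, 64))
except ValueError as err:
    print(err)

# Indexing the mask as row_idx[:, None] gives shape [tile0, 1], which lines up
# with the row axis of diff regardless of the tile sizes.
print(broadcast_shape((32, 1), (32, 64)))
```

Under these rules the repro's mask only broadcasts when the two tile sizes happen to match (or one of them is 1), which is consistent with the script passing under eager-style semantics on a 64 x 64 input. If the intent is a per-row mask, `(row_idx[:, None] > 0) & (diff >= 0)` is presumably the unambiguous form.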