Compile mode symbol broadcasting error doesn't repro in interpret mode #1200

@yf225

Description

Repro:

import torch
import helion
import helion.language as hl

@helion.kernel
def invalid_broadcast_kernel(x: torch.Tensor) -> torch.Tensor:
    M, N = x.shape
    out = torch.zeros_like(x)
    for row_tile, col_tile in hl.tile([M, N]):
        row_idx = row_tile.index  # shape [tile0]
        col_idx = col_tile.index  # shape [tile1]
        # Create 2D grid of differences
        diff = row_idx[:, None] - col_idx  # shape [tile0, tile1]
        # BUG: row_idx has shape [tile0], diff has shape [tile0, tile1]
        # PyTorch right-aligns: [tile0] -> [1, tile0]
        # Then broadcasts [1, tile0] with [tile0, tile1], which incorrectly
        # aligns tile0 dimension with tile1 dimension
        result = (row_idx > 0) & (diff >= 0)
        out[row_tile, col_tile] = result.float()
    return out

x = torch.randn(64, 64, device="cuda")
invalid_broadcast_kernel(x)

The script above runs without error under HELION_INTERPRET=1 but fails in compile mode (i.e. without that variable set). Ideally, interpret mode should have the same indexing/broadcasting behavior as compile mode and raise the same error as compile mode for anything unsupported.
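
To illustrate why the two modes might disagree, here is a minimal pure-Python sketch of the right-alignment broadcasting rule described in the repro's comments. It is not Helion's implementation: `broadcast_shape` is a hypothetical helper, and representing tile sizes as the strings "tile0" and "tile1" is an assumption made only to model distinct symbolic sizes. The guess it encodes is that interpret mode sees concrete integer sizes, which happen to match on a 64x64 input, while compile mode sees distinct symbols.

```python
def broadcast_shape(a, b):
    """Right-align two shapes and broadcast them dimension by dimension.

    Dimensions may be ints (concrete sizes) or strings (symbolic sizes).
    Two dimensions are compatible if they are equal or if either is 1.
    Hypothetical helper for illustration only.
    """
    result = []
    # Walk both shapes from the trailing dimension, padding the shorter with 1.
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da == db or db == 1:
            result.append(da)
        elif da == 1:
            result.append(db)
        else:
            raise ValueError(f"cannot broadcast {a} with {b}: {da} vs {db}")
    return tuple(reversed(result))


# Concrete sizes, as an eager run would see them when both tiles have the
# same size: [64] vs [64, 64] broadcasts silently.
concrete = broadcast_shape((64,), (64, 64))

# Symbolic sizes: tile0 and tile1 are distinct symbols, so the same
# expression is rejected.
try:
    broadcast_shape(("tile0",), ("tile0", "tile1"))
    symbolic_error = None
except ValueError as exc:
    symbolic_error = str(exc)

# Adding an explicit trailing dimension, as row_idx[:, None] does in the
# repro, makes the shapes compatible in both settings.
fixed = broadcast_shape(("tile0", 1), ("tile0", "tile1"))
```

Under this model, the concrete case succeeds only because the two sizes coincide, which is consistent with the repro passing in interpret mode on a square input. Whether interpret mode should detect this by tracking tile identity or by some other means is left open here.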
