Skip to content

Commit f86e7f7

Browse files
authored
Fix bug with unit sized dims and block_sizes (#932)
1 parent 1aaba3f commit f86e7f7

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

helion/_compiler/compile_environment.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -582,9 +582,11 @@ def from_config(
 @dataclasses.dataclass
 class LoopSpecBlockSizeSource(BlockSizeSource):
     def from_config(self, config: Config, block_size_info: BlockSizeInfo) -> int:
-        index = CompileEnvironment.current().config_spec.block_sizes.block_id_to_index(
-            block_size_info.block_id
-        )
+        env = CompileEnvironment.current()
+        size = block_size_info.size
+        if isinstance(size, (int, torch.SymInt)) and env.known_equal(size, 1):
+            return 1
+        index = env.config_spec.block_sizes.block_id_to_index(block_size_info.block_id)
         return config.block_sizes[index]
test/test_matmul.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
 from helion._testing import code_and_output
 from helion._testing import import_path
 from helion._testing import skipIfRefEager
+from helion._testing import skipIfRocm
 import helion.language as hl

 torch.backends.cuda.matmul.fp32_precision = "tf32"
@@ -272,6 +273,33 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-2)
         self.assertExpectedJournal(code)

+    @skipIfRocm("ROCm triton error in TritonAMDGPUBlockPingpong")
+    @skipIfRefEager("config_spec is not supported in ref eager mode")
+    def test_matmul_config_reuse_with_unit_dim(self):
+        torch.manual_seed(0)
+        big_args = (
+            torch.randn([64, 64], device=DEVICE, dtype=torch.float32),
+            torch.randn([64, 64], device=DEVICE, dtype=torch.float32),
+        )
+        big_bound = matmul_with_addmm.bind(big_args)
+        big_spec = big_bound.config_spec
+        self.assertEqual(len(big_spec.block_sizes), 3)
+        big_config = big_spec.default_config()
+
+        small_args = (
+            torch.randn([1, 64], device=DEVICE, dtype=torch.float32),
+            torch.randn([64, 64], device=DEVICE, dtype=torch.float32),
+        )
+        small_bound = matmul_with_addmm.bind(small_args)
+        small_spec = small_bound.config_spec
+        self.assertEqual(len(small_spec.block_sizes), 3)
+
+        # Previously raised when reusing configs tuned on larger shapes.
+        small_bound.set_config(big_config)
+        result = small_bound(*small_args)
+        expected = small_args[0] @ small_args[1]
+        torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-2)

     def test_matmul_packed_rhs(self):
         @helion.kernel(static_shapes=False)
         def matmul_with_packed_b(

0 commit comments

Comments (0)