Skip to content

Commit 879dc98

Browse files
committed
Fix bug with unit sized dims and block_sizes
stack-info: PR: #932, branch: jansel/stack/191
1 parent 33b32e5 commit 879dc98

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

helion/_compiler/compile_environment.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -584,9 +584,11 @@ def from_config(
584584
@dataclasses.dataclass
585585
class LoopSpecBlockSizeSource(BlockSizeSource):
586586
def from_config(self, config: Config, block_size_info: BlockSizeInfo) -> int:
587-
index = CompileEnvironment.current().config_spec.block_sizes.block_id_to_index(
588-
block_size_info.block_id
589-
)
587+
env = CompileEnvironment.current()
588+
size = block_size_info.size
589+
if isinstance(size, (int, torch.SymInt)) and env.known_equal(size, 1):
590+
return 1
591+
index = env.config_spec.block_sizes.block_id_to_index(block_size_info.block_id)
590592
return config.block_sizes[index]
591593

592594

test/test_matmul.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,31 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
272272
torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-2)
273273
self.assertExpectedJournal(code)
274274

275+
def test_matmul_config_reuse_with_unit_dim(self):
276+
torch.manual_seed(0)
277+
big_args = (
278+
torch.randn([64, 64], device=DEVICE, dtype=torch.float32),
279+
torch.randn([64, 64], device=DEVICE, dtype=torch.float32),
280+
)
281+
big_bound = matmul_with_addmm.bind(big_args)
282+
big_spec = big_bound.config_spec
283+
self.assertEqual(len(big_spec.block_sizes), 3)
284+
big_config = big_spec.default_config()
285+
286+
small_args = (
287+
torch.randn([1, 64], device=DEVICE, dtype=torch.float32),
288+
torch.randn([64, 64], device=DEVICE, dtype=torch.float32),
289+
)
290+
small_bound = matmul_with_addmm.bind(small_args)
291+
small_spec = small_bound.config_spec
292+
self.assertEqual(len(small_spec.block_sizes), 3)
293+
294+
# Previously raised when reusing configs tuned on larger shapes.
295+
small_bound.set_config(big_config)
296+
result = small_bound(*small_args)
297+
expected = small_args[0] @ small_args[1]
298+
torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-2)
299+
275300
def test_matmul_packed_rhs(self):
276301
@helion.kernel(static_shapes=False)
277302
def matmul_with_packed_b(

0 commit comments

Comments
 (0)