Skip to content

Commit dbabe8f

Browse files
committed
Fix bug with unit-sized dims and block_sizes
stack-info: PR: #932, branch: jansel/stack/191
1 parent cf7a08f commit dbabe8f

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

helion/_compiler/compile_environment.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -582,9 +582,11 @@ def from_config(
582582
@dataclasses.dataclass
583583
class LoopSpecBlockSizeSource(BlockSizeSource):
584584
def from_config(self, config: Config, block_size_info: BlockSizeInfo) -> int:
    """Return the block size for this dim from *config*.

    Dims whose size is statically known to equal 1 short-circuit to a
    block size of 1, so configs tuned on larger shapes stay reusable
    when a dim collapses to unit size.
    """
    env = CompileEnvironment.current()
    dim_size = block_size_info.size
    # A unit-sized dim never needs a tuned block size.
    if isinstance(dim_size, (int, torch.SymInt)) and env.known_equal(dim_size, 1):
        return 1
    block_index = env.config_spec.block_sizes.block_id_to_index(
        block_size_info.block_id
    )
    return config.block_sizes[block_index]
589591

590592

test/test_matmul.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,32 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
272272
torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-2)
273273
self.assertExpectedJournal(code)
274274

275+
@skipIfRefEager("config_spec is not supported in ref eager mode")
def test_matmul_config_reuse_with_unit_dim(self):
    """A config tuned on large shapes must be reusable on a unit-dim input."""
    torch.manual_seed(0)

    large_inputs = (
        torch.randn([64, 64], device=DEVICE, dtype=torch.float32),
        torch.randn([64, 64], device=DEVICE, dtype=torch.float32),
    )
    large_bound = matmul_with_addmm.bind(large_inputs)
    self.assertEqual(len(large_bound.config_spec.block_sizes), 3)
    tuned_config = large_bound.config_spec.default_config()

    small_inputs = (
        torch.randn([1, 64], device=DEVICE, dtype=torch.float32),
        torch.randn([64, 64], device=DEVICE, dtype=torch.float32),
    )
    small_bound = matmul_with_addmm.bind(small_inputs)
    self.assertEqual(len(small_bound.config_spec.block_sizes), 3)

    # Previously raised when reusing configs tuned on larger shapes.
    small_bound.set_config(tuned_config)
    result = small_bound(*small_inputs)
    expected = small_inputs[0] @ small_inputs[1]
    torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-2)
300+
275301
def test_matmul_packed_rhs(self):
276302
@helion.kernel(static_shapes=False)
277303
def matmul_with_packed_b(

0 commit comments

Comments
 (0)