 from helion._testing import code_and_output
 from helion._testing import import_path
 from helion._testing import skipIfRefEager
+from helion._testing import skipIfRocm
 import helion.language as hl

 torch.backends.cuda.matmul.fp32_precision = "tf32"
@@ -272,6 +273,33 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: |
     torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-2)
     self.assertExpectedJournal(code)

+    @skipIfRocm("ROCm triton error in TritonAMDGPUBlockPingpong")
+    @skipIfRefEager("config_spec is not supported in ref eager mode")
+    def test_matmul_config_reuse_with_unit_dim(self):
+        torch.manual_seed(0)
+        big_args = (
+            torch.randn([64, 64], device=DEVICE, dtype=torch.float32),
+            torch.randn([64, 64], device=DEVICE, dtype=torch.float32),
+        )
+        big_bound = matmul_with_addmm.bind(big_args)
+        big_spec = big_bound.config_spec
+        self.assertEqual(len(big_spec.block_sizes), 3)
+        big_config = big_spec.default_config()
+
+        small_args = (
+            torch.randn([1, 64], device=DEVICE, dtype=torch.float32),
+            torch.randn([64, 64], device=DEVICE, dtype=torch.float32),
+        )
+        small_bound = matmul_with_addmm.bind(small_args)
+        small_spec = small_bound.config_spec
+        self.assertEqual(len(small_spec.block_sizes), 3)
+
+        # Previously raised when reusing configs tuned on larger shapes.
+        small_bound.set_config(big_config)
+        result = small_bound(*small_args)
+        expected = small_args[0] @ small_args[1]
+        torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-2)
+
     def test_matmul_packed_rhs(self):
         @helion.kernel(static_shapes=False)
         def matmul_with_packed_b(
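For context, a minimal standalone sketch of the config-reuse pattern the new test guards. It assumes `matmul_with_addmm` is the Helion matmul kernel defined earlier in this test file and that `DEVICE` comes from `helion._testing`, and it uses only the bound-kernel API visible in the diff (`bind`, `config_spec`, `default_config`, `set_config`); treat it as an illustration under those assumptions, not the library's documented workflow:

    import torch

    # Take the default config produced for a large, "typical" shape.
    big = (
        torch.randn([64, 64], device=DEVICE, dtype=torch.float32),
        torch.randn([64, 64], device=DEVICE, dtype=torch.float32),
    )
    config = matmul_with_addmm.bind(big).config_spec.default_config()

    # Reuse that config on a degenerate shape with a unit dimension.
    # Before this change, set_config raised for such shape/config pairs.
    small = (
        torch.randn([1, 64], device=DEVICE, dtype=torch.float32),
        torch.randn([64, 64], device=DEVICE, dtype=torch.float32),
    )
    small_bound = matmul_with_addmm.bind(small)
    small_bound.set_config(config)
    torch.testing.assert_close(
        small_bound(*small), small[0] @ small[1], atol=1e-1, rtol=1e-2
    )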