@@ -272,6 +272,31 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
272272 torch .testing .assert_close (result , expected , atol = 1e-1 , rtol = 1e-2 )
273273 self .assertExpectedJournal (code )
274274
275+ def test_matmul_config_reuse_with_unit_dim (self ):
276+ torch .manual_seed (0 )
277+ big_args = (
278+ torch .randn ([64 , 64 ], device = DEVICE , dtype = torch .float32 ),
279+ torch .randn ([64 , 64 ], device = DEVICE , dtype = torch .float32 ),
280+ )
281+ big_bound = matmul_with_addmm .bind (big_args )
282+ big_spec = big_bound .config_spec
283+ self .assertEqual (len (big_spec .block_sizes ), 3 )
284+ big_config = big_spec .default_config ()
285+
286+ small_args = (
287+ torch .randn ([1 , 64 ], device = DEVICE , dtype = torch .float32 ),
288+ torch .randn ([64 , 64 ], device = DEVICE , dtype = torch .float32 ),
289+ )
290+ small_bound = matmul_with_addmm .bind (small_args )
291+ small_spec = small_bound .config_spec
292+ self .assertEqual (len (small_spec .block_sizes ), 3 )
293+
294+ # Previously raised when reusing configs tuned on larger shapes.
295+ small_bound .set_config (big_config )
296+ result = small_bound (* small_args )
297+ expected = small_args [0 ] @ small_args [1 ]
298+ torch .testing .assert_close (result , expected , atol = 1e-1 , rtol = 1e-2 )
299+
275300 def test_matmul_packed_rhs (self ):
276301 @helion .kernel (static_shapes = False )
277302 def matmul_with_packed_b (
0 commit comments