@@ -272,6 +272,32 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
272272 torch .testing .assert_close (result , expected , atol = 1e-1 , rtol = 1e-2 )
273273 self .assertExpectedJournal (code )
274274
275+ @skipIfRefEager ("config_spec is not supported in ref eager mode" )
276+ def test_matmul_config_reuse_with_unit_dim (self ):
277+ torch .manual_seed (0 )
278+ big_args = (
279+ torch .randn ([64 , 64 ], device = DEVICE , dtype = torch .float32 ),
280+ torch .randn ([64 , 64 ], device = DEVICE , dtype = torch .float32 ),
281+ )
282+ big_bound = matmul_with_addmm .bind (big_args )
283+ big_spec = big_bound .config_spec
284+ self .assertEqual (len (big_spec .block_sizes ), 3 )
285+ big_config = big_spec .default_config ()
286+
287+ small_args = (
288+ torch .randn ([1 , 64 ], device = DEVICE , dtype = torch .float32 ),
289+ torch .randn ([64 , 64 ], device = DEVICE , dtype = torch .float32 ),
290+ )
291+ small_bound = matmul_with_addmm .bind (small_args )
292+ small_spec = small_bound .config_spec
293+ self .assertEqual (len (small_spec .block_sizes ), 3 )
294+
295+ # Previously raised when reusing configs tuned on larger shapes.
296+ small_bound .set_config (big_config )
297+ result = small_bound (* small_args )
298+ expected = small_args [0 ] @ small_args [1 ]
299+ torch .testing .assert_close (result , expected , atol = 1e-1 , rtol = 1e-2 )
300+
275301 def test_matmul_packed_rhs (self ):
276302 @helion .kernel (static_shapes = False )
277303 def matmul_with_packed_b (
0 commit comments