@@ -112,6 +112,12 @@ def __init__(
112112 collections .Counter ()
113113 )
114114 self .specialized_vars : set [sympy .Symbol ] = set ()
115+ # Track size symbols (to distinguish from stride-only symbols)
116+ self .size_symbols : set [sympy .Symbol ] = set ()
117+ # Track stride symbols: maps sympy.Symbol to (tensor, dim) for strides
118+ self .stride_symbols : dict [sympy .Symbol , tuple [torch .Tensor , int ]] = {}
119+ # Track which (tensor, dim) strides have been specialized
120+ self .specialized_strides : set [tuple [torch .Tensor , int ]] = set ()
115121 self .loop_dependency_checker = LoopDependencyChecker ()
116122 self ._symint_cache : dict [object , torch .SymInt ] = {}
117123 self .device_load_count = (
@@ -469,6 +475,15 @@ def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor:
469475 self .debug_shape_renames [s ._sympy_ ()] = sympy .Symbol (
470476 f"{ source .local_name } _size{ i } " , integer = True
471477 )
478+ # Record size and stride symbols for specialization tracking
479+ for i in range (result .ndim ):
480+ sz = result .size (i )
481+ if isinstance (sz , torch .SymInt ):
482+ self .size_symbols .update (sz ._sympy_ ().free_symbols )
483+ st = result .stride (i )
484+ if isinstance (st , torch .SymInt ):
485+ for sym in st ._sympy_ ().free_symbols :
486+ self .stride_symbols [sym ] = (result , i )
472487 return result
473488
474489 def size_hint (self , n : int | torch .SymInt ) -> int :
0 commit comments