|
14 | 14 |
|
15 | 15 | import sympy |
16 | 16 | import torch |
| 17 | +from torch._dynamo.source import LocalSource |
17 | 18 | from torch._inductor.codegen.triton import TritonPrinter |
18 | 19 | from torch.fx.graph import _Namespace |
19 | 20 |
|
@@ -602,11 +603,23 @@ def tensor_size(self, fake_value: torch.Tensor, dim: int) -> Argument: |
602 | 603 | return self._tensor_property(TensorSizeArg, fake_value, dim, "size") |
603 | 604 |
|
604 | 605 | def tensor_stride(self, fake_value: torch.Tensor, dim: int) -> Argument: |
| 606 | + v = fake_value.stride(dim) |
| 607 | + env = CompileEnvironment.current() |
| 608 | + # Check if this stride was explicitly specialized |
| 609 | + source = env.input_sources.get(fake_value) |
605 | 610 | if ( |
606 | | - isinstance(v := fake_value.stride(dim), int) |
607 | | - and CompileEnvironment.current().settings.static_shapes |
| 611 | + isinstance(source, LocalSource) |
| 612 | + and (source.local_name, dim) in env.specialized_strides |
608 | 613 | ): |
609 | | - return StaticShape(v) |
| 614 | + return StaticShape(int(v)) |
| 615 | + if isinstance(v, int): |
| 616 | + if env.settings.static_shapes: |
| 617 | + return StaticShape(v) |
| 618 | + else: |
| 619 | + # Check if all free symbols are specialized |
| 620 | + syms = v._sympy_().free_symbols |
| 621 | + if syms and syms <= env.specialized_vars: |
| 622 | + return StaticShape(int(v)) |
610 | 623 | return self._tensor_property(TensorStrideArg, fake_value, dim, "stride") |
611 | 624 |
|
612 | 625 | def sorted_args(self) -> list[Argument]: |
|
0 commit comments