Skip to content

Commit 0e4cb60

Browse files
committed
up
1 parent ec0806b commit 0e4cb60

File tree

7 files changed

+45
-15
lines changed

7 files changed

+45
-15
lines changed

helion/_compiler/compile_environment.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def __init__(
112112
collections.Counter()
113113
)
114114
self.specialized_vars: set[sympy.Symbol] = set()
115+
self.specialized_strides: set[tuple[str, int]] = set()
115116
self.loop_dependency_checker = LoopDependencyChecker()
116117
self._symint_cache: dict[object, torch.SymInt] = {}
117118
self.device_load_count = (

helion/_compiler/device_function.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import sympy
1616
import torch
17+
from torch._dynamo.source import LocalSource
1718
from torch._inductor.codegen.triton import TritonPrinter
1819
from torch.fx.graph import _Namespace
1920

@@ -602,11 +603,23 @@ def tensor_size(self, fake_value: torch.Tensor, dim: int) -> Argument:
602603
return self._tensor_property(TensorSizeArg, fake_value, dim, "size")
603604

604605
def tensor_stride(self, fake_value: torch.Tensor, dim: int) -> Argument:
606+
v = fake_value.stride(dim)
607+
env = CompileEnvironment.current()
608+
# Check if this stride was explicitly specialized
609+
source = env.input_sources.get(fake_value)
605610
if (
606-
isinstance(v := fake_value.stride(dim), int)
607-
and CompileEnvironment.current().settings.static_shapes
611+
isinstance(source, LocalSource)
612+
and (source.local_name, dim) in env.specialized_strides
608613
):
609-
return StaticShape(v)
614+
return StaticShape(int(v))
615+
if isinstance(v, int):
616+
if env.settings.static_shapes:
617+
return StaticShape(v)
618+
else:
619+
# Check if all free symbols are specialized
620+
syms = v._sympy_().free_symbols
621+
if syms and syms <= env.specialized_vars:
622+
return StaticShape(int(v))
610623
return self._tensor_property(TensorStrideArg, fake_value, dim, "stride")
611624

612625
def sorted_args(self) -> list[Argument]:

helion/_compiler/type_propagation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -645,7 +645,7 @@ def propagate_call(
645645
attr = self.attr()
646646
if attr in {"dim", "ndimension"} and not (args or kwargs):
647647
return TypeInfo.from_example(self.tensor.fake_value.ndim, origin)
648-
if attr in {"shape", "size"} and not kwargs:
648+
if attr in {"shape", "size", "stride"} and not kwargs:
649649
fn = getattr(self.tensor.fake_value, attr)
650650
try:
651651
return TypeInfo.from_example(

helion/exc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ class SpecializeOnDevice(BaseError):
186186

187187

188188
class SpecializeArgType(BaseError):
189-
message = "hl.specialize() must be called on a size from an input tensor, got: {}"
189+
message = "hl.specialize() must be called on a size or stride from an input tensor, got: {}"
190190

191191

192192
class StackTensorcOnHost(BaseError):

helion/language/constexpr.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
from typing_extensions import TypeVar
77

88
import torch
9+
from torch._dynamo.source import LocalSource
10+
from torch._dynamo.source import TensorProperty
11+
from torch._dynamo.source import TensorPropertySource
912

1013
from .. import exc
1114
from .._compiler.ast_extension import expr_from_string
@@ -87,7 +90,18 @@ def _(value: TypeInfo, *, origin: Origin) -> TypeInfo:
8790
env = CompileEnvironment.current()
8891

8992
def handle_symint(symint: torch.SymInt) -> int:
90-
env.specialized_vars.update(symint._sympy_().free_symbols)
93+
syms = symint._sympy_().free_symbols
94+
env.specialized_vars.update(syms)
95+
# Track stride specializations
96+
for sym in syms:
97+
for source in env.shape_env.var_to_sources.get(sym, []):
98+
if (
99+
isinstance(source, TensorPropertySource)
100+
and source.prop == TensorProperty.STRIDE
101+
and isinstance(source.base, LocalSource)
102+
and source.idx is not None
103+
):
104+
env.specialized_strides.add((source.base.local_name, source.idx))
91105
return symint.__int__()
92106

93107
specialized = _convert_specializable(proxy, on_symint=handle_symint)

helion/runtime/kernel.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -624,12 +624,14 @@ def _specialize_extra(self) -> list[Callable[[Sequence[object]], Hashable]]:
624624

625625
def make_extractor(v: Source) -> Callable[[Sequence[object]], Hashable]:
626626
if isinstance(v, TensorPropertySource):
627-
assert v.prop == TensorProperty.SIZE
628627
index = v.idx
629628
assert index is not None
630629
inner = make_extractor(v.base)
631-
632-
return lambda args: cast("torch.Tensor", inner(args)).size(index)
630+
if v.prop == TensorProperty.SIZE:
631+
return lambda args: cast("torch.Tensor", inner(args)).size(index)
632+
if v.prop == TensorProperty.STRIDE:
633+
return lambda args: cast("torch.Tensor", inner(args)).stride(index)
634+
raise exc.SpecializeArgType(v)
633635
if isinstance(v, LocalSource):
634636
index = arg_name_to_index[v.local_name]
635637
return operator.itemgetter(index)

test/test_examples.expected

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1674,18 +1674,18 @@ def _helion_fused_linear_jsd_kernel(student_logits, teacher_logits, loss, temper
16741674
# src[fused_linear_jsd.py:N]: teacher_div = torch.nn.functional.kl_div(
16751675
# src[fused_linear_jsd.py:N]: torch.log(m), teacher_prob, reduction="none", log_target=True
16761676
# src[fused_linear_jsd.py:N]: ).sum(dim=-1)
1677-
v_17 = teacher_prob_1 - v_16
1678-
v_18 = libdevice.exp(teacher_prob_1)
1679-
v_19 = v_18 * v_17
1677+
v_17 = libdevice.exp(teacher_prob_1)
1678+
v_18 = teacher_prob_1 - v_16
1679+
v_19 = v_17 * v_18
16801680
teacher_div = tl.cast(tl.sum(v_19, 1), tl.float32)
16811681
# src[fused_linear_jsd.py:N]: torch.log(m), student_prob, reduction="none", log_target=True
16821682
v_20 = tl_math.log(v_15)
16831683
# src[fused_linear_jsd.py:N]: student_div = torch.nn.functional.kl_div(
16841684
# src[fused_linear_jsd.py:N]: torch.log(m), student_prob, reduction="none", log_target=True
16851685
# src[fused_linear_jsd.py:N]: ).sum(dim=-1)
1686-
v_21 = student_prob_1 - v_20
1687-
v_22 = libdevice.exp(student_prob_1)
1688-
v_23 = v_22 * v_21
1686+
v_21 = libdevice.exp(student_prob_1)
1687+
v_22 = student_prob_1 - v_20
1688+
v_23 = v_21 * v_22
16891689
student_div = tl.cast(tl.sum(v_23, 1), tl.float32)
16901690
# src[fused_linear_jsd.py:N]: batch_loss = student_div + beta * (teacher_div - student_div)
16911691
v_24 = teacher_div - student_div

0 commit comments

Comments
 (0)