
Commit de9d90f

Commit message: up

1 parent 41d144f

File tree (7 files changed: +41 -16 lines changed)

  helion/_compiler/compile_environment.py
  helion/_compiler/device_function.py
  helion/_compiler/type_propagation.py
  helion/exc.py
  helion/language/constexpr.py
  helion/runtime/kernel.py
  test/test_examples.expected

helion/_compiler/compile_environment.py
Lines changed: 15 additions & 0 deletions

@@ -112,6 +112,12 @@ def __init__(
             collections.Counter()
         )
         self.specialized_vars: set[sympy.Symbol] = set()
+        # Track size symbols (to distinguish from stride-only symbols)
+        self.size_symbols: set[sympy.Symbol] = set()
+        # Track stride symbols: maps sympy.Symbol to (tensor, dim) for strides
+        self.stride_symbols: dict[sympy.Symbol, tuple[torch.Tensor, int]] = {}
+        # Track which (tensor, dim) strides have been specialized
+        self.specialized_strides: set[tuple[torch.Tensor, int]] = set()
         self.loop_dependency_checker = LoopDependencyChecker()
         self._symint_cache: dict[object, torch.SymInt] = {}
         self.device_load_count = (
@@ -469,6 +475,15 @@ def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor:
                 self.debug_shape_renames[s._sympy_()] = sympy.Symbol(
                     f"{source.local_name}_size{i}", integer=True
                 )
+        # Record size and stride symbols for specialization tracking
+        for i in range(result.ndim):
+            sz = result.size(i)
+            if isinstance(sz, torch.SymInt):
+                self.size_symbols.update(sz._sympy_().free_symbols)
+            st = result.stride(i)
+            if isinstance(st, torch.SymInt):
+                for sym in st._sympy_().free_symbols:
+                    self.stride_symbols[sym] = (result, i)
         return result

     def size_hint(self, n: int | torch.SymInt) -> int:
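
The stride-only check in constexpr.py below relies on the fact that a symbol backing both a size and a stride ends up in both of the collections recorded here. A minimal, self-contained sketch of that bookkeeping (plain sympy, not Helion's actual classes; the tensor "x" and its layout are hypothetical):

    import sympy

    s0, s1 = sympy.symbols("s0 s1", integer=True)
    # A contiguous 2-D tensor "x": sizes (s0, s1), strides (s1, 1), so
    # stride(0) is backed by the same symbol as size(1).
    sizes, strides = (s0, s1), (s1, 1)

    size_symbols: set[sympy.Symbol] = set()
    stride_symbols: dict[sympy.Symbol, tuple[str, int]] = {}
    for dim in range(2):
        size_symbols |= sizes[dim].free_symbols
        if isinstance(strides[dim], sympy.Expr):
            for sym in strides[dim].free_symbols:
                stride_symbols[sym] = ("x", dim)

    # s1 lands in BOTH sets, so it will not be treated as stride-only.
    assert s1 in size_symbols and s1 in stride_symbols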

helion/_compiler/device_function.py
Lines changed: 6 additions & 4 deletions

@@ -602,10 +602,12 @@ def tensor_size(self, fake_value: torch.Tensor, dim: int) -> Argument:
         return self._tensor_property(TensorSizeArg, fake_value, dim, "size")

     def tensor_stride(self, fake_value: torch.Tensor, dim: int) -> Argument:
-        if (
-            isinstance(v := fake_value.stride(dim), int)
-            and CompileEnvironment.current().settings.static_shapes
-        ):
+        v = fake_value.stride(dim)
+        env = CompileEnvironment.current()
+        # Check if this specific stride was specialized
+        if (fake_value, dim) in env.specialized_strides:
+            return StaticShape(int(v))
+        if isinstance(v, int) and env.settings.static_shapes:
             return StaticShape(v)
         return self._tensor_property(TensorStrideArg, fake_value, dim, "stride")
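
The specialized-stride check runs before the static_shapes fallback, so a stride specialized via hl.specialize() is emitted as a compile-time constant even when static_shapes is off; everything else still becomes a TensorStrideArg kernel argument. An illustrative-only model of that lookup order (not Helion's API):

    def resolve_stride(v, stride_specialized: bool, static_shapes: bool):
        if stride_specialized:  # (fake_value, dim) in env.specialized_strides
            return ("constant", int(v))
        if isinstance(v, int) and static_shapes:
            return ("constant", v)
        return ("kernel_arg", v)  # lowered as a TensorStrideArg

    assert resolve_stride(64, True, False) == ("constant", 64)
    assert resolve_stride(64, False, False) == ("kernel_arg", 64)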

helion/_compiler/type_propagation.py
Lines changed: 1 addition & 1 deletion

@@ -645,7 +645,7 @@ def propagate_call(
         attr = self.attr()
         if attr in {"dim", "ndimension"} and not (args or kwargs):
             return TypeInfo.from_example(self.tensor.fake_value.ndim, origin)
-        if attr in {"shape", "size"} and not kwargs:
+        if attr in {"shape", "size", "stride"} and not kwargs:
             fn = getattr(self.tensor.fake_value, attr)
             try:
                 return TypeInfo.from_example(

helion/exc.py
Lines changed: 1 addition & 1 deletion

@@ -186,7 +186,7 @@ class SpecializeOnDevice(BaseError):


 class SpecializeArgType(BaseError):
-    message = "hl.specialize() must be called on a size from an input tensor, got: {}"
+    message = "hl.specialize() must be called on a size or stride from an input tensor, got: {}"


 class StackTensorcOnHost(BaseError):

helion/language/constexpr.py
Lines changed: 7 additions & 1 deletion

@@ -87,7 +87,13 @@ def _(value: TypeInfo, *, origin: Origin) -> TypeInfo:
     env = CompileEnvironment.current()

     def handle_symint(symint: torch.SymInt) -> int:
-        env.specialized_vars.update(symint._sympy_().free_symbols)
+        syms = symint._sympy_().free_symbols
+        env.specialized_vars.update(syms)
+        # Track which specific strides were specialized (by sympy symbol)
+        # Only mark as specialized stride if symbol is stride-only (not also a size)
+        for sym in syms:
+            if sym in env.stride_symbols and sym not in env.size_symbols:
+                env.specialized_strides.add(env.stride_symbols[sym])
         return symint.__int__()

     specialized = _convert_specializable(proxy, on_symint=handle_symint)
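
Together with the type_propagation.py and exc.py changes above, this lets hl.specialize() accept a stride from an input tensor, not just a size. A hedged usage sketch (the kernel body is illustrative, not from this commit; only the hl.specialize(x.stride(0)) call exercises the new path):

    import torch
    import helion
    import helion.language as hl

    @helion.kernel
    def scaled_copy(x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        # New with this commit: specializing on a stride bakes it into the
        # generated code as a constant (and into the specialization key).
        row_stride = hl.specialize(x.stride(0))  # a plain int at compile time
        for tile in hl.tile(x.size()):
            out[tile] = x[tile] * 2.0
        return out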

helion/runtime/kernel.py
Lines changed: 5 additions & 3 deletions

@@ -624,12 +624,14 @@ def _specialize_extra(self) -> list[Callable[[Sequence[object]], Hashable]]:

     def make_extractor(v: Source) -> Callable[[Sequence[object]], Hashable]:
         if isinstance(v, TensorPropertySource):
-            assert v.prop == TensorProperty.SIZE
             index = v.idx
             assert index is not None
             inner = make_extractor(v.base)
-
-            return lambda args: cast("torch.Tensor", inner(args)).size(index)
+            if v.prop == TensorProperty.SIZE:
+                return lambda args: cast("torch.Tensor", inner(args)).size(index)
+            if v.prop == TensorProperty.STRIDE:
+                return lambda args: cast("torch.Tensor", inner(args)).stride(index)
+            raise exc.SpecializeArgType(v)
         if isinstance(v, LocalSource):
             index = arg_name_to_index[v.local_name]
             return operator.itemgetter(index)
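
These extractors compute the specialization cache key from the live call arguments, so two calls that differ only in a specialized stride are compiled separately. A standalone sketch of what the new TensorProperty.STRIDE branch produces (plain PyTorch; the helper name is hypothetical):

    import torch

    def make_stride_extractor(arg_index: int, dim: int):
        # Mirrors the lambda returned for TensorProperty.STRIDE above.
        return lambda args: args[arg_index].stride(dim)

    extract = make_stride_extractor(0, 0)
    a = torch.empty(4, 8)            # contiguous: stride(0) == 8
    b = torch.empty(4, 16)[:, ::2]   # strided view: stride(0) == 16
    assert (extract((a,)), extract((b,))) == (8, 16)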

test/test_examples.expected
Lines changed: 6 additions & 6 deletions

@@ -1674,18 +1674,18 @@ def _helion_fused_linear_jsd_kernel(student_logits, teacher_logits, loss, temper
     # src[fused_linear_jsd.py:N]: teacher_div = torch.nn.functional.kl_div(
     # src[fused_linear_jsd.py:N]:     torch.log(m), teacher_prob, reduction="none", log_target=True
     # src[fused_linear_jsd.py:N]: ).sum(dim=-1)
-    v_17 = teacher_prob_1 - v_16
-    v_18 = libdevice.exp(teacher_prob_1)
-    v_19 = v_18 * v_17
+    v_17 = libdevice.exp(teacher_prob_1)
+    v_18 = teacher_prob_1 - v_16
+    v_19 = v_17 * v_18
     teacher_div = tl.cast(tl.sum(v_19, 1), tl.float32)
     # src[fused_linear_jsd.py:N]:     torch.log(m), student_prob, reduction="none", log_target=True
     v_20 = tl_math.log(v_15)
     # src[fused_linear_jsd.py:N]: student_div = torch.nn.functional.kl_div(
     # src[fused_linear_jsd.py:N]:     torch.log(m), student_prob, reduction="none", log_target=True
     # src[fused_linear_jsd.py:N]: ).sum(dim=-1)
-    v_21 = student_prob_1 - v_20
-    v_22 = libdevice.exp(student_prob_1)
-    v_23 = v_22 * v_21
+    v_21 = libdevice.exp(student_prob_1)
+    v_22 = student_prob_1 - v_20
+    v_23 = v_21 * v_22
     student_div = tl.cast(tl.sum(v_23, 1), tl.float32)
     # src[fused_linear_jsd.py:N]: batch_loss = student_div + beta * (teacher_div - student_div)
     v_24 = teacher_div - student_div
