Skip to content

Commit e8c62c7

Browse files
committed
up
1 parent 4e279bb commit e8c62c7

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

helion/_compiler/type_propagation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -645,7 +645,7 @@ def propagate_call(
645645
attr = self.attr()
646646
if attr in {"dim", "ndimension"} and not (args or kwargs):
647647
return TypeInfo.from_example(self.tensor.fake_value.ndim, origin)
648-
if attr in {"shape", "size"} and not kwargs:
648+
if attr in {"shape", "size", "stride"} and not kwargs:
649649
fn = getattr(self.tensor.fake_value, attr)
650650
try:
651651
return TypeInfo.from_example(

helion/exc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ class SpecializeOnDevice(BaseError):
186186

187187

188188
class SpecializeArgType(BaseError):
189-
message = "hl.specialize() must be called on a size from an input tensor, got: {}"
189+
message = "hl.specialize() must be called on a size or stride from an input tensor, got: {}"
190190

191191

192192
class StackTensorcOnHost(BaseError):

helion/runtime/kernel.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -624,12 +624,14 @@ def _specialize_extra(self) -> list[Callable[[Sequence[object]], Hashable]]:
624624

625625
def make_extractor(v: Source) -> Callable[[Sequence[object]], Hashable]:
626626
if isinstance(v, TensorPropertySource):
627-
assert v.prop == TensorProperty.SIZE
628627
index = v.idx
629628
assert index is not None
630629
inner = make_extractor(v.base)
631-
632-
return lambda args: cast("torch.Tensor", inner(args)).size(index)
630+
if v.prop == TensorProperty.SIZE:
631+
return lambda args: cast("torch.Tensor", inner(args)).size(index)
632+
if v.prop == TensorProperty.STRIDE:
633+
return lambda args: cast("torch.Tensor", inner(args)).stride(index)
634+
raise exc.SpecializeArgType(v)
633635
if isinstance(v, LocalSource):
634636
index = arg_name_to_index[v.local_name]
635637
return operator.itemgetter(index)

0 commit comments

Comments
 (0)