Skip to content

Commit 3d06ccd

Browse files
Corrected test
Signed-off-by: Keshav Vinayak Jha <[email protected]>
1 parent 3946f96 commit 3d06ccd

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from torch_mlir_e2e_test.framework import TestUtils
1111
from torch_mlir_e2e_test.registry import register_test_case
1212
from torch_mlir_e2e_test.annotations import annotate_args, export
13+
from torch._higher_order_ops.while_loop import while_loop
1314

1415
# ==============================================================================
1516

@@ -87,6 +88,12 @@ class TorchPrimLoopWhileLikeHOPModule(torch.nn.Module):
8788
def __init__(self):
8889
super().__init__()
8990

91+
def body_fn(self, i, x):
92+
return i + 1, x + 1
93+
94+
def cond_fn(self, i, x):
95+
return i < 3
96+
9097
@export
9198
@annotate_args(
9299
[
@@ -95,14 +102,8 @@ def __init__(self):
95102
]
96103
)
97104
def forward(self, x: torch.Tensor) -> torch.Tensor:
98-
from torch._higher_order_ops.while_loop import while_loop
99-
100-
def body_fn(i, x):
101-
return i + 1, x + 1
102-
103105
i0 = torch.tensor(0)
104-
105-
out_i, out_x = while_loop(lambda i, x: i < 3, body_fn, (i0, x))
106+
out_i, out_x = while_loop(self.cond_fn, self.body_fn, (i0, x))
106107
return out_i, out_x
107108

108109

0 commit comments

Comments
 (0)