Skip to content

Commit 3c3478f

Browse files
committed
Add test for iterating over np.ndarray.
1 parent 9187dac commit 3c3478f

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

tests/unit_tests/test_nodes.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from opto.trace import node
33
from opto.trace import operators as ops
44
from opto.trace.utils import contain
5-
5+
import numpy as np
66

77
# Sum of str
88
x = node("NodeX")
@@ -151,4 +151,11 @@ def fun(x):
151151
assert x.description == "[ParameterNode] x"
152152

153153
x = node(1, trainable=True)
154-
assert x.description == "[ParameterNode] This is a ParameterNode in a computational graph."
154+
assert x.description == "[ParameterNode] This is a ParameterNode in a computational graph."
155+
156+
157+
# Test iterating numpy array
158+
x = node(np.array([1, 2, 3]))
159+
for i, v in enumerate(x):
160+
assert isinstance(v, type(x))
161+
assert v.data == x.data[i]

0 commit comments

Comments
 (0)