Skip to content

Commit f990cd1

Browse files
authored
Revert "test_nonzero 测试样例修复 (#3057)" (#3058)
This reverts commit 3c7600e.
1 parent 3c7600e commit f990cd1

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

framework/api/paddlebase/test_nonzero.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,10 @@ def test_nonzero2():
6464
x = paddle.to_tensor(np.array([[1.0, 1.0, 4.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]]).astype(np.float32))
6565
as_tuple_ = True
6666
outputs = paddle.nonzero(x, as_tuple_)
67-
res = np.array([
68-
[0, 0, 0, 1, 2],
69-
[0, 1, 2, 1, 2]
70-
]).astype(np.int64)
71-
outputs_np = np.stack([i.numpy() for i in outputs], axis=0)
72-
npt.assert_allclose(outputs_np, res)
67+
res = np.array([[[0], [0], [0], [1], [2]], [[0], [1], [2], [1], [2]]]).astype(np.int64)
68+
for i in range(outputs.__len__()):
69+
out = outputs[i].numpy()
70+
npt.assert_allclose(out, res[i, :, :])
7371

7472

7573
@pytest.mark.api_base_nonzero_parameters
@@ -94,9 +92,10 @@ def test_nonzero4():
9492
x = paddle.to_tensor(np.array([2, 1, 0, 3]).astype(np.int32))
9593
as_tuple_ = True
9694
outputs = paddle.nonzero(x, as_tuple_)
97-
res = np.array([[0, 1, 3]]).astype(np.int64)
98-
outputs_np = np.stack([i.numpy() for i in outputs], axis=0)
99-
npt.assert_allclose(outputs_np, res)
95+
res = np.array([[[0], [1], [3]]]).astype(np.int64)
96+
for i in range(outputs.__len__()):
97+
out = outputs[i].numpy()
98+
npt.assert_allclose(out, res[i, :])
10099

101100

102101
@pytest.mark.api_base_nonzero_parameters
@@ -154,11 +153,12 @@ def test_nonzero6():
154153
outputs = paddle.nonzero(x, as_tuple_)
155154
res = np.array(
156155
[
157-
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2],
158-
[0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1],
159-
[0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1],
160-
[0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1]
156+
[[0.0], [0.0], [0.0], [0.0], [0.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [2.0], [2.0], [2.0], [2.0]],
157+
[[0.0], [0.0], [1.0], [1.0], [1.0], [0.0], [0.0], [0.0], [1.0], [1.0], [1.0], [0.0], [1.0], [1.0], [1.0]],
158+
[[0.0], [0.0], [0.0], [1.0], [1.0], [0.0], [0.0], [1.0], [0.0], [1.0], [1.0], [0.0], [0.0], [1.0], [1.0]],
159+
[[0.0], [1.0], [0.0], [0.0], [1.0], [0.0], [1.0], [1.0], [0.0], [0.0], [1.0], [0.0], [1.0], [0.0], [1.0]],
161160
]
162161
).astype(np.int64)
163-
outputs_np = np.stack([i.numpy() for i in outputs], axis=0)
164-
npt.assert_allclose(outputs_np, res)
162+
for i in range(outputs.__len__()):
163+
out = outputs[i].numpy()
164+
npt.assert_allclose(out, res[i, :, :])

0 commit comments

Comments
 (0)