Skip to content

Commit 748524a

Browse files
authored
转换规则 No.16/17/18 (#142)
* 规则转换 No.16/17/18. * 增加cummin,searchsorted转换测试。 * fix error * fix error at SearchsortedMatcher * fix code style error
1 parent 2b1000d commit 748524a

File tree

5 files changed

+332
-31
lines changed

5 files changed

+332
-31
lines changed

paconvert/api_mapping.json

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3398,6 +3398,19 @@
33983398
"device"
33993399
]
34003400
},
3401+
"torch.cummin": {
3402+
"Matcher": "TupleAssignMatcher",
3403+
"paddle_api": "paddle.cummin",
3404+
"args_list": [
3405+
"input",
3406+
"dim",
3407+
"out"
3408+
],
3409+
"kwargs_change": {
3410+
"input": "x",
3411+
"dim": "axis"
3412+
}
3413+
},
34013414
"torch.cumprod": {
34023415
"Matcher": "CumprodMatcher",
34033416
"paddle_api": "paddle.cumprod",
@@ -8418,6 +8431,18 @@
84188431
"dtype": "paddle.float32"
84198432
}
84208433
},
8434+
"torch.searchsorted": {
8435+
"Matcher": "SearchsortedMatcher",
8436+
"args_list": [
8437+
"sorted_sequence",
8438+
"values",
8439+
"out_int32",
8440+
"right",
8441+
"side",
8442+
"out",
8443+
"sorter"
8444+
]
8445+
},
84218446
"torch.seed": {
84228447
"Matcher": "SeedMatcher"
84238448
},
@@ -9214,7 +9239,18 @@
92149239
"tensor": "x"
92159240
}
92169241
},
9217-
"torch.vander": {},
9242+
"torch.vander": {
9243+
"Matcher": "GenericMatcher",
9244+
"paddle_api": "paddle.vander",
9245+
"args_list": [
9246+
"x",
9247+
"N",
9248+
"increasing"
9249+
],
9250+
"kwargs_change": {
9251+
"N": "n"
9252+
}
9253+
},
92189254
"torch.var": {
92199255
"Matcher": "GenericMatcher",
92209256
"paddle_api": "paddle.var",

paconvert/api_matcher.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2867,6 +2867,30 @@ def generate_code(self, kwargs):
28672867
return code
28682868

28692869

2870+
class SearchsortedMatcher(BaseMatcher):
2871+
def generate_code(self, kwargs):
2872+
2873+
if "side" in kwargs:
2874+
kwargs["right"] = kwargs.pop("side").strip("\n") + "== 'right'"
2875+
2876+
if "sorter" in kwargs and kwargs["sorter"] is not None:
2877+
kwargs[
2878+
"sorted_sequence"
2879+
] += ".take_along_axis(axis=-1, indices = {})".format(
2880+
kwargs.pop("sorter").strip("\n")
2881+
)
2882+
2883+
code = "paddle.searchsorted({})".format(self.kwargs_to_str(kwargs))
2884+
2885+
if "out" in kwargs and kwargs["out"] is not None:
2886+
out_v = kwargs.pop("out").strip("\n")
2887+
code = "paddle.assign(paddle.searchsorted({}), output={})".format(
2888+
self.kwargs_to_str(kwargs), out_v
2889+
)
2890+
2891+
return code
2892+
2893+
28702894
class SincMatcher(BaseMatcher):
28712895
def generate_code(self, kwargs):
28722896
if "input" not in kwargs:

tests/test_cummin.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import textwrap
16+
17+
from apibase import APIBase
18+
19+
obj = APIBase("torch.cummin")
20+
21+
22+
def test_case_1():
23+
pytorch_code = textwrap.dedent(
24+
"""
25+
import torch
26+
x = torch.tensor([[1.0, 1.0, 1.0],
27+
[2.0, 2.0, 2.0],
28+
[3.0, 3.0, 3.0]])
29+
result = torch.cummin(x, 0)
30+
"""
31+
)
32+
obj.run(pytorch_code, ["result"])
33+
34+
35+
def test_case_2():
36+
pytorch_code = textwrap.dedent(
37+
"""
38+
import torch
39+
x = torch.tensor([[1.0, 1.0, 1.0],
40+
[2.0, 2.0, 2.0],
41+
[3.0, 3.0, 3.0]])
42+
result = torch.cummin(x, dim=1)
43+
"""
44+
)
45+
obj.run(pytorch_code, ["result"])
46+
47+
48+
def test_case_3():
49+
pytorch_code = textwrap.dedent(
50+
"""
51+
import torch
52+
x = torch.tensor([[1.0, 1.0, 1.0],
53+
[2.0, 2.0, 2.0],
54+
[3.0, 3.0, 3.0]])
55+
result = torch.cummin(input=x, dim=1)
56+
"""
57+
)
58+
obj.run(pytorch_code, ["result"])
59+
60+
61+
def test_case_4():
62+
pytorch_code = textwrap.dedent(
63+
"""
64+
import torch
65+
x = torch.tensor([[1.0, 1.0, 1.0],
66+
[2.0, 2.0, 2.0],
67+
[3.0, 3.0, 3.0]])
68+
values = torch.tensor([[1.0, 1.0, 1.0],
69+
[2.0, 2.0, 2.0],
70+
[3.0, 3.0, 3.0]]).float()
71+
indices = torch.tensor([[1, 1, 1],
72+
[2, 2, 2],
73+
[3, 3, 3]])
74+
out = (values, indices)
75+
result = torch.cummin(x, 0, out=(values, indices))
76+
"""
77+
)
78+
obj.run(pytorch_code, ["result", "out"])
79+
80+
81+
def test_case_5():
82+
pytorch_code = textwrap.dedent(
83+
"""
84+
import torch
85+
x = torch.tensor([[1.0, 1.0, 1.0],
86+
[2.0, 2.0, 2.0],
87+
[3.0, 3.0, 3.0]])
88+
values = torch.tensor([[1.0, 1.0, 1.0],
89+
[2.0, 2.0, 2.0],
90+
[3.0, 3.0, 3.0]]).float()
91+
indices = torch.tensor([[1, 1, 1],
92+
[2, 2, 2],
93+
[3, 3, 3]])
94+
out = (values, indices)
95+
result = torch.cummin(x, dim = 0, out=(values, indices))
96+
"""
97+
)
98+
obj.run(pytorch_code, ["result", "out"])
99+
100+
101+
def test_case_6():
102+
pytorch_code = textwrap.dedent(
103+
"""
104+
import torch
105+
x = torch.tensor([[1.0, 1.0, 1.0],
106+
[2.0, 2.0, 2.0],
107+
[3.0, 3.0, 3.0]])
108+
values = torch.tensor([[1.0, 1.0, 1.0],
109+
[2.0, 2.0, 2.0],
110+
[3.0, 3.0, 3.0]]).float()
111+
indices = torch.tensor([[1, 1, 1],
112+
[2, 2, 2],
113+
[3, 3, 3]])
114+
out = (values, indices)
115+
result = torch.cummin(input = x, dim =0, out=(values, indices))
116+
"""
117+
)
118+
obj.run(pytorch_code, ["result", "out"])

tests/test_searchsorted.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import textwrap
16+
17+
from apibase import APIBase
18+
19+
obj = APIBase("torch.searchsorted")
20+
21+
22+
def test_case_1():
23+
pytorch_code = textwrap.dedent(
24+
"""
25+
import torch
26+
x = torch.tensor([[ 1, 3, 5, 7, 9],
27+
[ 2, 4, 6, 8, 10]])
28+
values = torch.tensor([[3, 6, 9],
29+
[3, 6, 9]])
30+
result = torch.searchsorted(x, values)
31+
"""
32+
)
33+
obj.run(pytorch_code, ["result"])
34+
35+
36+
def test_case_2():
37+
pytorch_code = textwrap.dedent(
38+
"""
39+
import torch
40+
x = torch.tensor([[ 1, 3, 5, 7, 9],
41+
[ 2, 4, 6, 8, 10]])
42+
values = torch.tensor([[3, 6, 9],
43+
[3, 6, 9]])
44+
result = torch.searchsorted(x, values, out_int32 = True)
45+
"""
46+
)
47+
obj.run(pytorch_code, ["result"])
48+
49+
50+
def test_case_3():
51+
pytorch_code = textwrap.dedent(
52+
"""
53+
import torch
54+
x = torch.tensor([[ 1, 3, 5, 7, 9],
55+
[ 2, 4, 6, 8, 10]])
56+
values = torch.tensor([[3, 6, 9],
57+
[3, 6, 9]])
58+
result = torch.searchsorted(x, values, right = True)
59+
"""
60+
)
61+
obj.run(pytorch_code, ["result"])
62+
63+
64+
def test_case_4():
65+
pytorch_code = textwrap.dedent(
66+
"""
67+
import torch
68+
x = torch.tensor([[ 1, 3, 5, 7, 9],
69+
[ 2, 4, 6, 8, 10]])
70+
values = torch.tensor([[3, 6, 9],
71+
[3, 6, 9]])
72+
result = torch.searchsorted(x, values, side = 'right')
73+
"""
74+
)
75+
obj.run(pytorch_code, ["result"])
76+
77+
78+
def test_case_5():
79+
pytorch_code = textwrap.dedent(
80+
"""
81+
import torch
82+
x = torch.tensor([[ 1, 3, 5, 7, 9],
83+
[ 2, 4, 6, 8, 10]])
84+
values = torch.tensor([[3, 6, 9],
85+
[3, 6, 9]])
86+
out = torch.tensor([[3, 6, 9],
87+
[3, 6, 9]])
88+
result = torch.searchsorted(x, values, out = out)
89+
"""
90+
)
91+
obj.run(pytorch_code, ["result"])
92+
93+
94+
def test_case_6():
95+
pytorch_code = textwrap.dedent(
96+
"""
97+
import torch
98+
x = torch.tensor([[ 1, 3, 9, 7, 5],
99+
[ 2, 4, 6, 8, 10]])
100+
values = torch.tensor([[3, 6, 9],
101+
[3, 6, 9]])
102+
sorter = torch.argsort(x)
103+
result = torch.searchsorted(x, values, sorter = sorter)
104+
"""
105+
)
106+
obj.run(pytorch_code, ["result"])
107+
108+
109+
def test_case_7():
110+
pytorch_code = textwrap.dedent(
111+
"""
112+
import torch
113+
x = torch.tensor([[ 1, 3, 9, 7, 5],
114+
[ 2, 4, 6, 8, 10]])
115+
values = torch.tensor([[3, 6, 9],
116+
[3, 6, 9]])
117+
out = torch.tensor([[3, 6, 9],
118+
[3, 6, 9]])
119+
sorter = torch.argsort(x)
120+
result = torch.searchsorted(x, values, right = True, side = 'right', out = out, sorter = sorter)
121+
"""
122+
)
123+
obj.run(pytorch_code, ["result"])
124+
125+
126+
def test_case_8():
127+
pytorch_code = textwrap.dedent(
128+
"""
129+
import torch
130+
x = torch.tensor([[ 1, 3, 5, 7, 9],
131+
[ 2, 4, 6, 8, 10]])
132+
values = torch.tensor([[3, 6, 9],
133+
[3, 6, 9]])
134+
result = torch.searchsorted(x, values, right = False, side = 'right')
135+
"""
136+
)
137+
obj.run(pytorch_code, ["result"])

0 commit comments

Comments
 (0)