Skip to content

Commit 2b1000d

Browse files
authored
转换规则 No.272/282/318/324/326 (#149)
* Add tests * Fix * Fix
1 parent c735d8f commit 2b1000d

File tree

6 files changed

+305
-1
lines changed

6 files changed

+305
-1
lines changed

paconvert/api_mapping.json

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1024,7 +1024,20 @@
10241024
"max"
10251025
]
10261026
},
1027-
"torch.Tensor.histogram": {},
1027+
"torch.Tensor.histogram": {
1028+
"Matcher": "TensorHistogramMatcher",
1029+
"paddle_api": "paddle.Tensor.histogram",
1030+
"args_list": [
1031+
"bins",
1032+
"range",
1033+
"weight",
1034+
"density"
1035+
],
1036+
"unsupport_args": [
1037+
"weight",
1038+
"density"
1039+
]
1040+
},
10281041
"torch.Tensor.hsplit": {},
10291042
"torch.Tensor.hypot": {
10301043
"Matcher": "HypotMatcher",
@@ -3186,6 +3199,19 @@
31863199
"input": "x"
31873200
}
31883201
},
3202+
"torch.corrcoef": {
3203+
"Matcher": "GenericMatcher",
3204+
"paddle_api": "paddle.linalg.corrcoef",
3205+
"args_list": [
3206+
"input"
3207+
],
3208+
"kwargs_change": {
3209+
"input": "x"
3210+
},
3211+
"paddle_default_kwargs": {
3212+
"rowvar": true
3213+
}
3214+
},
31893215
"torch.cos": {
31903216
"Matcher": "GenericMatcher",
31913217
"paddle_api": "paddle.cos",
@@ -4109,6 +4135,17 @@
41094135
"input": "x"
41104136
}
41114137
},
4138+
"torch.frexp": {
4139+
"Matcher": "GenericMatcher",
4140+
"paddle_api": "paddle.frexp",
4141+
"args_list": [
4142+
"input",
4143+
"out"
4144+
],
4145+
"kwargs_change": {
4146+
"input": "x"
4147+
}
4148+
},
41124149
"torch.from_numpy": {
41134150
"Matcher": "GenericMatcher",
41144151
"paddle_api": "paddle.to_tensor",
@@ -7424,6 +7461,17 @@
74247461
"input": "x"
74257462
}
74267463
},
7464+
"torch.nn.functional.mish": {
7465+
"Matcher": "GenericMatcher",
7466+
"paddle_api": "paddle.nn.functional.mish",
7467+
"args_list": [
7468+
"input",
7469+
"inplace"
7470+
],
7471+
"kwargs_change": {
7472+
"input": "x"
7473+
}
7474+
},
74277475
"torch.nn.functional.mse_loss": {
74287476
"Matcher": "SizeAverageMatcher",
74297477
"paddle_api": "paddle.nn.functional.mse_loss",
@@ -8585,6 +8633,19 @@
85858633
"input": "x"
85868634
}
85878635
},
8636+
"torch.special.log_softmax": {
8637+
"Matcher": "GenericMatcher",
8638+
"paddle_api": "paddle.nn.functional.log_softmax",
8639+
"args_list": [
8640+
"input",
8641+
"dim",
8642+
"dtype"
8643+
],
8644+
"kwargs_change": {
8645+
"input": "x",
8646+
"dim": "axis"
8647+
}
8648+
},
85888649
"torch.special.logsumexp": {
85898650
"Matcher": "LogsumexpMatcher",
85908651
"paddle_api": "paddle.logsumexp",

paconvert/api_matcher.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2978,6 +2978,15 @@ def generate_code(self, kwargs):
29782978
return code
29792979

29802980

2981+
class TensorHistogramMatcher(BaseMatcher):
2982+
def generate_code(self, kwargs):
2983+
if "range" in kwargs:
2984+
kwargs["min"] = "int({}[0])".format(kwargs["range"])
2985+
kwargs["max"] = "int({}[1])".format(kwargs["range"])
2986+
del kwargs["range"]
2987+
return GenericMatcher.generate_code(self, kwargs)
2988+
2989+
29812990
class SpecialNdtriMatcher(BaseMatcher):
29822991
def generate_code(self, kwargs):
29832992

tests/test_Tensor_histogram.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import textwrap
16+
17+
from apibase import APIBase
18+
19+
obj = APIBase("torch.Tensor.histogram")
20+
21+
22+
def test_case_1():
23+
pytorch_code = textwrap.dedent(
24+
"""
25+
import torch
26+
result = torch.tensor([[1., 2, 1]]).histogram(bins=4, range=(0., 3.))
27+
if hasattr(result, "hist"):
28+
result = result.hist
29+
result = result.to(torch.float32)
30+
"""
31+
)
32+
obj.run(pytorch_code, ["result"])
33+
34+
35+
def test_case_2():
36+
pytorch_code = textwrap.dedent(
37+
"""
38+
import torch
39+
input = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
40+
result = input.histogram(bins=4, range=(0., 3.))
41+
if hasattr(result, "hist"):
42+
result = result.hist
43+
result = result.to(torch.float32)
44+
"""
45+
)
46+
obj.run(pytorch_code, ["result"])
47+
48+
49+
def test_case_3():
50+
pytorch_code = textwrap.dedent(
51+
"""
52+
import torch
53+
result = torch.tensor([[1., 2, 1]]).histogram()
54+
if hasattr(result, "hist"):
55+
result = result.hist
56+
result = result.to(torch.float32)
57+
"""
58+
)
59+
obj.run(pytorch_code, ["result"])

tests/test_corrcoef.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import textwrap
16+
17+
from apibase import APIBase
18+
19+
obj = APIBase("torch.corrcoef")
20+
21+
22+
def test_case_1():
23+
pytorch_code = textwrap.dedent(
24+
"""
25+
import torch
26+
x = torch.tensor([[ 0.7308, 1.0060, 0.5270, 1.4516],
27+
[-0.1383, 1.5706, 0.4724, 0.4141],
28+
[ 0.1193, 0.2829, 0.9037, 0.3957],
29+
[-0.8202, -0.6474, -0.1631, -0.6543]])
30+
result = torch.corrcoef(x)
31+
"""
32+
)
33+
obj.run(pytorch_code, ["result"])
34+
35+
36+
def test_case_2():
37+
pytorch_code = textwrap.dedent(
38+
"""
39+
import torch
40+
x = torch.tensor([[-0.1533, 2.3020, -0.1771, 0.5928],
41+
[ 0.4338, -0.6537, 0.2296, 0.5946],
42+
[-0.4932, 1.8386, -0.1039, 1.0440],
43+
[ 0.1735, -0.8303, -0.3821, -0.4384],
44+
[-0.1533, 2.3020, -0.1771, 0.5928],
45+
[ 0.4338, -0.6537, 0.2296, 0.5946],
46+
[-0.4932, 1.8386, -0.1039, 1.0440],
47+
[ 0.1735, -0.8303, -0.3821, -0.4384]])
48+
result = torch.corrcoef(x)
49+
"""
50+
)
51+
obj.run(pytorch_code, ["result"])
52+
53+
54+
def test_case_3():
55+
pytorch_code = textwrap.dedent(
56+
"""
57+
import torch
58+
x = torch.tensor([[ 0.7308, 1.0060, 0.5270, 1.4516],
59+
[-0.1383, 1.5706, 0.4724, 0.4141],
60+
[ 0.1193, 0.2829, 0.9037, 0.3957],
61+
[-0.8202, -0.6474, -0.1631, -0.6543]])
62+
result = torch.corrcoef(input=x)
63+
"""
64+
)
65+
obj.run(pytorch_code, ["result"])

tests/test_nn_functional_mish.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import textwrap
16+
17+
from apibase import APIBase
18+
19+
obj = APIBase("torch.nn.functional.mish")
20+
21+
22+
def test_case_1():
23+
pytorch_code = textwrap.dedent(
24+
"""
25+
import torch
26+
import torch.nn.functional as F
27+
x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913],
28+
[-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]])
29+
result = F.mish(x)
30+
"""
31+
)
32+
obj.run(pytorch_code, ["result"])
33+
34+
35+
def test_case_2():
36+
pytorch_code = textwrap.dedent(
37+
"""
38+
import torch
39+
import torch.nn.functional as F
40+
x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913],
41+
[-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]])
42+
result = F.mish(input=x)
43+
"""
44+
)
45+
obj.run(pytorch_code, ["result"])
46+
47+
48+
def test_case_3():
49+
pytorch_code = textwrap.dedent(
50+
"""
51+
import torch
52+
import torch.nn.functional as F
53+
x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913],
54+
[-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]])
55+
result = F.mish(input=x, inplace=False)
56+
"""
57+
)
58+
obj.run(pytorch_code, ["result"])

tests/test_special_log_softmax.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import textwrap
16+
17+
from apibase import APIBase
18+
19+
obj = APIBase("torch.special.log_softmax")
20+
21+
22+
def test_case_1():
23+
pytorch_code = textwrap.dedent(
24+
"""
25+
import torch
26+
input = torch.tensor([1.4907, 1.0593, 1.5696])
27+
result = torch.special.log_softmax(input, 0)
28+
"""
29+
)
30+
obj.run(pytorch_code, ["result"])
31+
32+
33+
def test_case_2():
34+
pytorch_code = textwrap.dedent(
35+
"""
36+
import torch
37+
input = torch.tensor([[1.4907, 1.0593, 1.5696], [1.4907, 1.0593, 1.5696]])
38+
result = torch.special.log_softmax(input, dim=1)
39+
"""
40+
)
41+
obj.run(pytorch_code, ["result"])
42+
43+
44+
def test_case_3():
45+
pytorch_code = textwrap.dedent(
46+
"""
47+
import torch
48+
input = torch.tensor([[1.4907, 1.0593, 1.5696], [1.4907, 1.0593, 1.5696]])
49+
result = torch.special.log_softmax(input, 1, dtype=torch.float32)
50+
"""
51+
)
52+
obj.run(pytorch_code, ["result"])

0 commit comments

Comments
 (0)