Skip to content

Commit fafdd56

Browse files
authored
Upsample (#115)
* add nn.Upsample,CUDAExtension,CppExtension,SequentialSampler,is_sparse * fix cpp_extension ut * fix cpp_extension ut * fix UtilsCppExtensionMatcher ci * fix UtilsCppExtensionMatcher * fix cpp_extension ut * UtilsCppExtensionMatcher bug * UtilsCppExtensionMatcher bug * fix UtilsCppExtensionMatcher pop * fix cpp test * add cpp test * add Attribute2Func
1 parent 748524a commit fafdd56

8 files changed

+420
-1
lines changed

paconvert/api_mapping.json

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6732,6 +6732,19 @@
67326732
"unflattened_size": "shape"
67336733
}
67346734
},
6735+
"torch.nn.Upsample": {
6736+
"Matcher": "GenericMatcher",
6737+
"paddle_api": "paddle.nn.Upsample",
6738+
"args_list": [
6739+
"size",
6740+
"scale_factor",
6741+
"mode",
6742+
"align_corners"
6743+
],
6744+
"unsupport_args": [
6745+
"recompute_scale_factor"
6746+
]
6747+
},
67356748
"torch.nn.UpsamplingBilinear2d": {
67366749
"Matcher": "UpsampleMatcher",
67376750
"paddle_api": "paddle.nn.UpsamplingBilinear2D",
@@ -9147,10 +9160,32 @@
91479160
"Matcher": "GenericMatcher",
91489161
"paddle_api": "paddle.utils.cpp_extension.BuildExtension.with_options"
91499162
},
9163+
"torch.utils.cpp_extension.CUDAExtension": {
9164+
"Matcher": "GenericMatcher",
9165+
"paddle_api": "paddle.utils.cpp_extension.CUDAExtension",
9166+
"args_list": [
9167+
"name",
9168+
"sources"
9169+
],
9170+
"kwargs_change": {
9171+
"name": ""
9172+
}
9173+
},
91509174
"torch.utils.cpp_extension.CUDA_HOME": {
91519175
"Matcher": "GenericMatcher",
91529176
"paddle_api": "paddle.utils.cpp_extension.cpp_extension.CUDA_HOME"
91539177
},
9178+
"torch.utils.cpp_extension.CppExtension": {
9179+
"Matcher": "GenericMatcher",
9180+
"paddle_api": "paddle.utils.cpp_extension.CppExtension",
9181+
"args_list": [
9182+
"name",
9183+
"sources"
9184+
],
9185+
"kwargs_change": {
9186+
"name": ""
9187+
}
9188+
},
91549189
"torch.utils.data.BatchSampler": {
91559190
"Matcher": "TorchUtilDataBatchSampler",
91569191
"args_list": [
@@ -9203,6 +9238,13 @@
92039238
"data_source"
92049239
]
92059240
},
9241+
"torch.utils.data.SequentialSampler": {
9242+
"Matcher": "GenericMatcher",
9243+
"paddle_api": "paddle.io.SequenceSampler",
9244+
"args_list": [
9245+
"data_source"
9246+
]
9247+
},
92069248
"torch.utils.data.default_collate": {
92079249
"Matcher": "GenericMatcher",
92089250
"paddle_api": "paddle.io.dataloader.collate.default_collate_fn",

paconvert/api_matcher.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3572,6 +3572,13 @@ def generate_code(self, kwargs):
35723572
return GenericMatcher.generate_code(self, kwargs)
35733573

35743574

3575+
class Attribute2Func(BaseMatcher):
3576+
def get_paddle_class_attribute_nodes(self, node):
3577+
self.parse_func(node)
3578+
code = "{}()".format(self.paddle_api)
3579+
return ast.parse(code).body[0].value
3580+
3581+
35753582
class LuMatcher(BaseMatcher):
35763583
def generate_code(self, kwargs):
35773584
out_v = kwargs.pop("out") if "out" in kwargs else None

paconvert/attribute_mapping.json

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
},
2323
"torch.Tensor.is_meta": {},
2424
"torch.Tensor.is_quantized": {},
25-
"torch.Tensor.is_sparse": {},
25+
"torch.Tensor.is_sparse": {
26+
"Matcher": "Attribute2Func",
27+
"paddle_api": "paddle.Tensor.is_sparse"
28+
},
2629
"torch.Tensor.mH": {},
2730
"torch.Tensor.mT": {},
2831
"torch.Tensor.names": {},

tests/test_Tensor_is_sparse.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import textwrap
15+
16+
from apibase import APIBase
17+
18+
obj = APIBase("torch.Tensor.is_sparse")
19+
20+
21+
def test_case_1():
22+
pytorch_code = textwrap.dedent(
23+
"""
24+
import torch
25+
a = torch.tensor([[ 0.9254, -0.6213]])
26+
result = a.is_sparse
27+
"""
28+
)
29+
obj.run(pytorch_code, ["result"])

tests/test_nn_Upsample.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import textwrap
16+
17+
from apibase import APIBase
18+
19+
obj = APIBase("torch.nn.Upsample")
20+
21+
22+
def test_case_1():
23+
pytorch_code = textwrap.dedent(
24+
"""
25+
import torch
26+
input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857],
27+
[-1.2533, -0.9829, -1.0981],
28+
[ 0.1507, -1.1431, -2.0361]],
29+
30+
[[ 0.1024, -0.4482, 0.4137],
31+
[ 0.9385, 0.4565, 0.7702],
32+
[ 0.4135, -0.2587, 0.0482]]]])
33+
m = torch.nn.Upsample(scale_factor=2, mode='nearest')
34+
result = m(input)
35+
"""
36+
)
37+
obj.run(pytorch_code, ["result"])
38+
39+
40+
def test_case_2():
41+
pytorch_code = textwrap.dedent(
42+
"""
43+
import torch
44+
input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857],
45+
[-1.2533, -0.9829, -1.0981],
46+
[ 0.1507, -1.1431, -2.0361]],
47+
48+
[[ 0.1024, -0.4482, 0.4137],
49+
[ 0.9385, 0.4565, 0.7702],
50+
[ 0.4135, -0.2587, 0.0482]]]])
51+
m = torch.nn.Upsample(scale_factor=2, mode='bilinear')
52+
result = m(input)
53+
"""
54+
)
55+
obj.run(pytorch_code, ["result"])
56+
57+
58+
def test_case_3():
59+
pytorch_code = textwrap.dedent(
60+
"""
61+
import torch
62+
input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857],
63+
[-1.2533, -0.9829, -1.0981],
64+
[ 0.1507, -1.1431, -2.0361]],
65+
66+
[[ 0.1024, -0.4482, 0.4137],
67+
[ 0.9385, 0.4565, 0.7702],
68+
[ 0.4135, -0.2587, 0.0482]]]])
69+
m = torch.nn.Upsample(scale_factor=2, mode='bilinear',align_corners=True)
70+
result = m(input)
71+
"""
72+
)
73+
obj.run(pytorch_code, ["result"])
74+
75+
76+
def test_case_4():
77+
pytorch_code = textwrap.dedent(
78+
"""
79+
import torch
80+
input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857],
81+
[-1.2533, -0.9829, -1.0981],
82+
[ 0.1507, -1.1431, -2.0361]],
83+
84+
[[ 0.1024, -0.4482, 0.4137],
85+
[ 0.9385, 0.4565, 0.7702],
86+
[ 0.4135, -0.2587, 0.0482]]]])
87+
m = torch.nn.Upsample(size=(2,2))
88+
result = m(input)
89+
"""
90+
)
91+
obj.run(pytorch_code, ["result"])
92+
93+
94+
def test_case_5():
95+
pytorch_code = textwrap.dedent(
96+
"""
97+
import torch
98+
input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857],
99+
[-1.2533, -0.9829, -1.0981],
100+
[ 0.1507, -1.1431, -2.0361]],
101+
102+
[[ 0.1024, -0.4482, 0.4137],
103+
[ 0.9385, 0.4565, 0.7702],
104+
[ 0.4135, -0.2587, 0.0482]]]])
105+
m = torch.nn.Upsample(scale_factor=2, mode='bilinear',align_corners=False)
106+
result = m(input)
107+
"""
108+
)
109+
obj.run(pytorch_code, ["result"])
110+
111+
112+
def test_case_6():
113+
pytorch_code = textwrap.dedent(
114+
"""
115+
import torch
116+
input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857],
117+
[-1.2533, -0.9829, -1.0981],
118+
[ 0.1507, -1.1431, -2.0361]],
119+
120+
[[ 0.1024, -0.4482, 0.4137],
121+
[ 0.9385, 0.4565, 0.7702],
122+
[ 0.4135, -0.2587, 0.0482]]]])
123+
m = torch.nn.Upsample(scale_factor=2, mode='bilinear',recompute_scale_factor=True)
124+
result = m(input)
125+
"""
126+
)
127+
obj.run(
128+
pytorch_code, unsupport=True, reason="paddle unsupport recompute_scale_factor "
129+
)
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import textwrap
16+
17+
from apibase import APIBase
18+
19+
obj = APIBase("torch.utils.cpp_extension.CUDAExtension")
20+
21+
22+
# The cuda compile not supports
23+
def test_case_1():
24+
pytorch_code = textwrap.dedent(
25+
"""
26+
from torch.utils.cpp_extension import CUDAExtension
27+
28+
CUDAExtension(
29+
name='cuda_extension',
30+
sources=['extension.cpp', 'extension_kernel.cu'],
31+
extra_compile_args={'cxx': ['-g'],
32+
'nvcc': ['-O2']})
33+
result = True
34+
"""
35+
)
36+
obj.run(pytorch_code, ["result"])
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import textwrap
16+
17+
from apibase import APIBase
18+
19+
obj = APIBase("torch.utils.cpp_extension.CppExtension")
20+
21+
22+
# The cpp compile not supports
23+
def test_case_1():
24+
pytorch_code = textwrap.dedent(
25+
"""
26+
from torch.utils.cpp_extension import CppExtension
27+
28+
CppExtension(
29+
name='cuda_extension',
30+
sources=['extension.cpp'],
31+
extra_compile_args=['-g'])
32+
result = True
33+
"""
34+
)
35+
obj.run(pytorch_code, ["result"])

0 commit comments

Comments
 (0)