Skip to content

Commit abd8f71

Browse files
Attach MulticlassNMSOBB (#1508)
* Attach MulticlassNMSOBB in edge-mdt-cl-dev, and add unittests. (#1507) * Updated repository URLs in README.md. (#1507) * **Temporary workaround**: Change requirements.txt from edge-mdt-cl to edge-mdt-cl-dev (temporary) (#1506)
1 parent 7a5fe9b commit abd8f71

File tree

5 files changed

+241
-5
lines changed

5 files changed

+241
-5
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ Pip install the model compression toolkit package in a Python>=3.9 environment w
3434
```
3535
pip install model-compression-toolkit
3636
```
37-
For installing the nightly version or installing from source, refer to the [installation guide](https://github.com/sony/model_optimization/blob/main/INSTALLATION.md).
37+
For installing the nightly version or installing from source, refer to the [installation guide](https://github.com/SonySemiconductorSolutions/mct-model-optimization/blob/main/INSTALLATION.md).
3838

3939
**Important note**: In order to use MCT, you’ll need to provide a pre-trained floating point model (PyTorch/Keras) as an input.
4040

model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
1+
# Copyright 2025 Sony Semiconductor Solutions, Inc. All rights reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -32,7 +32,7 @@
3232
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2fw import \
3333
AttachTpcToFramework
3434
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attribute_filter import Eq
35-
from edgemdt_cl.pytorch import MulticlassNMS, MulticlassNMSWithIndices
35+
from edgemdt_cl.pytorch import MulticlassNMS, MulticlassNMSWithIndices, MulticlassNMSOBB
3636

3737

3838
class AttachTpcToPytorch(AttachTpcToFramework):
@@ -98,7 +98,7 @@ def __init__(self):
9898
OperatorSetNames.L2NORM: [LayerFilterParams(torch.nn.functional.normalize,
9999
Eq('p', 2) | Eq('p', None))],
100100
OperatorSetNames.SSD_POST_PROCESS: [], # no such operator in pytorch
101-
OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [MulticlassNMS, MulticlassNMSWithIndices],
101+
OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [MulticlassNMS, MulticlassNMSWithIndices, MulticlassNMSOBB],
102102
OperatorSetNames.EXP: [torch.exp],
103103
OperatorSetNames.SIN: [torch.sin],
104104
OperatorSetNames.COS: [torch.cos],

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@ scipy
1111
protobuf
1212
mct-quantizers==1.6.0
1313
pydantic>=2.0
14-
edge-mdt-cl>=1.0
14+
edge-mdt-cl-dev
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# Copyright 2025 Sony Semiconductor Solutions, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
import os
17+
import onnx
18+
import onnxruntime as ort
19+
import numpy as np
20+
import torch
21+
import torch.nn as nn
22+
import torch.nn.functional as F
23+
from typing import Iterator, List
24+
import model_compression_toolkit as mct
25+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
26+
AttributeQuantizationConfig, Signedness
27+
from tests.common_tests.helpers.tpcs_for_tests.v4.tpc import generate_tpc
28+
from mct_quantizers import QuantizationMethod
29+
from edgemdt_cl.pytorch.nms_obb import MulticlassNMSOBB, NMSOBBResults
30+
from edgemdt_cl.pytorch import load_custom_ops
31+
32+
33+
def get_representative_dataset(n_iter: int):
34+
def representative_dataset() -> Iterator[List]:
35+
for _ in range(n_iter):
36+
yield [torch.rand(1, 3, 64, 64)]
37+
38+
return representative_dataset
39+
40+
41+
def get_tpc():
42+
43+
att_cfg_noquant = AttributeQuantizationConfig()
44+
op_cfg = OpQuantizationConfig(default_weight_attr_config=att_cfg_noquant,
45+
attr_weights_configs_mapping={},
46+
activation_quantization_method=QuantizationMethod.UNIFORM,
47+
activation_n_bits=8,
48+
supported_input_activation_n_bits=8,
49+
enable_activation_quantization=False,
50+
quantization_preserving=False,
51+
fixed_scale=None,
52+
fixed_zero_point=None,
53+
simd_size=32,
54+
signedness=Signedness.AUTO)
55+
56+
tpc = generate_tpc(default_config=op_cfg, base_config=op_cfg, mixed_precision_cfg_list=[op_cfg], name="test_tpc")
57+
return tpc
58+
59+
60+
class NMSOBBModel(nn.Module):
61+
62+
def __init__(self, num_classes=2, max_detections=300, score_threshold=0.001, iou_threshold=0.7):
63+
64+
super().__init__()
65+
self.max_detections = max_detections
66+
67+
self.backbone = nn.Sequential(
68+
nn.Conv2d(3, 16, kernel_size=3, padding=1),
69+
nn.ReLU(),
70+
nn.MaxPool2d(2, 2))
71+
72+
self.bbox_reg = nn.Conv2d(16, 4 * max_detections, kernel_size=1)
73+
self.class_reg = nn.Conv2d(16, num_classes * max_detections, kernel_size=1)
74+
self.angle_reg = nn.Conv2d(16, max_detections, kernel_size=1)
75+
self.multiclass_nms_obb = MulticlassNMSOBB(score_threshold, iou_threshold, max_detections)
76+
77+
def forward(self, x):
78+
79+
batch = x.size(0)
80+
features = self.backbone(x)
81+
H_prime, W_prime = features.shape[2], features.shape[3]
82+
83+
boxes = self.bbox_reg(features)
84+
boxes = boxes.view(batch, self.max_detections, 4, H_prime * W_prime).mean(dim=3)
85+
scores = self.class_reg(features).view(batch, self.max_detections, -1, H_prime * W_prime)
86+
scores = F.softmax(scores.mean(dim=3), dim=2)
87+
angles = self.angle_reg(features)
88+
angles = angles.view(batch, self.max_detections, 1, H_prime * W_prime).mean(dim=3)
89+
90+
nms_res = self.multiclass_nms_obb(boxes, scores, angles)
91+
return nms_res
92+
93+
94+
class TestMulticlassNMSOBB():
95+
96+
def test_multiclass_nms_obb(self):
97+
98+
max_detections = 300
99+
score_threshold = 0.001
100+
iou_threshold = 0.7
101+
102+
model = NMSOBBModel(num_classes=2, max_detections=max_detections, score_threshold=score_threshold, iou_threshold=iou_threshold)
103+
104+
tpc = get_tpc()
105+
q_model, _ = mct.ptq.pytorch_post_training_quantization(model,
106+
get_representative_dataset(n_iter=1),
107+
target_resource_utilization=None,
108+
core_config=mct.core.CoreConfig(),
109+
target_platform_capabilities=tpc)
110+
111+
_, last_layer = list(q_model.named_children())[-1]
112+
113+
assert isinstance(last_layer, MulticlassNMSOBB)
114+
assert last_layer.score_threshold == score_threshold
115+
assert last_layer.iou_threshold == iou_threshold
116+
assert last_layer.max_detections == max_detections
117+
118+
dummy_x = torch.rand(1, 3, 64, 64)
119+
res = q_model(dummy_x)
120+
assert isinstance(res, NMSOBBResults)
121+
assert res.boxes.shape == (1, max_detections, 4) # boxes
122+
assert res.scores.shape == (1, max_detections) # scores
123+
assert res.labels.shape == (1, max_detections) # labels
124+
assert res.angles.shape == (1, max_detections) # angles
125+
assert res.n_valid.shape == (1, 1) # n_valid
126+
127+
# export onnx
128+
onnx_model_path = './qmodel_with_nms_obb.onnx'
129+
mct.exporter.pytorch_export_model(model=q_model,
130+
save_model_path=onnx_model_path,
131+
repr_dataset=get_representative_dataset(n_iter=1))
132+
assert os.path.exists(onnx_model_path) == True
133+
134+
# load onnx
135+
onnx_model = onnx.load(onnx_model_path)
136+
onnx.checker.check_model(onnx_model, full_check=True)
137+
opset_info = list(onnx_model.opset_import)[1]
138+
assert opset_info.domain == 'EdgeMDT' and opset_info.version == 1
139+
140+
nms_obb_node = list(onnx_model.graph.node)[-1]
141+
assert nms_obb_node.domain == 'EdgeMDT'
142+
assert nms_obb_node.op_type == 'MultiClassNMSOBB'
143+
assert len(nms_obb_node.input) == 3
144+
assert len(nms_obb_node.output) == 5
145+
146+
attrs = sorted(nms_obb_node.attribute, key=lambda a: a.name)
147+
assert attrs[0].name == 'iou_threshold'
148+
np.isclose(attrs[0].f, iou_threshold)
149+
assert attrs[1].name == 'max_detections'
150+
assert attrs[1].i == max_detections
151+
assert attrs[2].name == 'score_threshold'
152+
np.isclose(attrs[2].f, score_threshold)
153+
154+
# check for ort
155+
so = load_custom_ops()
156+
session = ort.InferenceSession(onnx_model_path, sess_options=so)
157+
ort_res = session.run(output_names=None, input_feed={'input': dummy_x.numpy()})
158+
159+
assert ort_res[0].shape == (1, max_detections, 4) # boxes
160+
assert ort_res[1].shape == (1, max_detections) # scores
161+
assert ort_res[2].shape == (1, max_detections) # labels
162+
assert ort_res[3].shape == (1, max_detections) # angles
163+
assert ort_res[4].shape == (1, 1) # n_valid
164+
165+
for i in range(len(res)):
166+
assert np.allclose(res[i].detach().numpy(), ort_res[i])
167+
168+
# delete onnx model
169+
if os.path.exists(onnx_model_path):
170+
os.remove(onnx_model_path)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2025 Sony Semiconductor Solutions, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities
17+
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import AttachTpcToPytorch
18+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
19+
AttributeQuantizationConfig, Signedness
20+
from tests.common_tests.helpers.tpcs_for_tests.v4.tpc import generate_tpc
21+
from mct_quantizers import QuantizationMethod
22+
from edgemdt_cl.pytorch.nms_obb import MulticlassNMSOBB
23+
24+
25+
def get_tpc():
26+
"""
27+
Create a target platform capabilities (TPC) configuration with no weight and activation quantization.
28+
29+
Returns a TPC object for quantization tests.
30+
"""
31+
att_cfg_noquant = AttributeQuantizationConfig()
32+
33+
op_cfg = OpQuantizationConfig(default_weight_attr_config=att_cfg_noquant,
34+
attr_weights_configs_mapping={},
35+
activation_quantization_method=QuantizationMethod.UNIFORM,
36+
activation_n_bits=8,
37+
supported_input_activation_n_bits=2,
38+
enable_activation_quantization=False,
39+
quantization_preserving=False,
40+
fixed_scale=None,
41+
fixed_zero_point=None,
42+
simd_size=32,
43+
signedness=Signedness.AUTO)
44+
45+
tpc = generate_tpc(default_config=op_cfg, base_config=op_cfg, mixed_precision_cfg_list=[op_cfg], name="test_tpc")
46+
47+
return tpc
48+
49+
50+
def test_attach2pytorch_nms_obb_tpc():
51+
52+
tpc = get_tpc()
53+
tpc = load_target_platform_capabilities(tpc)
54+
55+
attach2pytorch = AttachTpcToPytorch()
56+
fqc = attach2pytorch.attach(tpc)
57+
58+
assert MulticlassNMSOBB in attach2pytorch._opset2layer['CombinedNonMaxSuppression']
59+
60+
qc = fqc.layer2qco[MulticlassNMSOBB].quantization_configurations[0]
61+
62+
assert qc.default_weight_attr_config.enable_weights_quantization == False
63+
assert qc.default_weight_attr_config.weights_n_bits == 32
64+
assert qc.attr_weights_configs_mapping == {}
65+
assert qc.enable_activation_quantization == False
66+
assert qc.activation_n_bits == 8

0 commit comments

Comments
 (0)