-
Couldn't load subscription status.
- Fork 76
Fix TPC for stack layer #1519
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
kawakami-masaki0
merged 8 commits into
SonySemiconductorSolutions:feature_stack
from
kawakami-masaki0:add_tpc_stack
Oct 2, 2025
Merged
Fix TPC for stack layer #1519
Changes from 4 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
c2e316e
Fix TPC for stack
kawakami-masaki0 59bf9ec
Fix test check
kawakami-masaki0 7784e4a
Add check for tf.stack
kawakami-masaki0 b401d6b
Fix qc_options
kawakami-masaki0 7440354
Add line breaks and comments
kawakami-masaki0 4b3b914
Formatting
kawakami-masaki0 1905f4c
Changed check for operator_set
kawakami-masaki0 4671aa6
Add qc_options for stack layer
kawakami-masaki0 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
29 changes: 29 additions & 0 deletions
29
tests_pytest/common_tests/unit_tests/target_platform_capabilities/test_tpc_stack.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| # Copyright 2025 Sony Semiconductor Solutions, Inc. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ============================================================================== | ||
|
|
||
| from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc import get_tpc as get_tpc_imx500_v1 | ||
|
|
||
| def test_tpc_stack(): | ||
|
|
||
| tpc = get_tpc_imx500_v1() # only imx500 supported | ||
| opset = tpc.operator_set[0] | ||
yt0705 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| assert opset.name == 'Stack' | ||
| assert opset.qc_options is None | ||
|
|
||
| qc = tpc.default_qco.quantization_configurations[0] # Stack layer applies default qc_options | ||
| assert qc.default_weight_attr_config.enable_weights_quantization == False | ||
| assert qc.attr_weights_configs_mapping == {} | ||
| assert qc.enable_activation_quantization == True | ||
| assert qc.activation_n_bits == 8 | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,53 @@ | ||
| # Copyright 2025 Sony Semiconductor Solutions, Inc. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ============================================================================== | ||
| from typing import Iterator, List | ||
| import numpy as np | ||
| import tensorflow as tf | ||
| from tensorflow.keras import layers | ||
| import keras | ||
| import model_compression_toolkit as mct | ||
| from mct_quantizers import KerasActivationQuantizationHolder | ||
|
|
||
|
|
||
| def get_model(): | ||
| inputs = keras.layers.Input((32, 32, 3)) | ||
| out1 = layers.Conv2D(16, kernel_size=3, padding='same', activation='relu')(inputs) | ||
| out2 = layers.Conv2D(16, kernel_size=3, padding='same', activation='relu')(inputs) | ||
| outputs = tf.stack([out1, out2], -1) | ||
| return keras.Model(inputs, outputs) | ||
|
|
||
|
|
||
| def get_representative_dataset(n_iter=1): | ||
| def representative_dataset() -> Iterator[List]: | ||
| for _ in range(n_iter): | ||
| yield [np.random.randn(1, 32, 32, 3)] | ||
| return representative_dataset | ||
|
|
||
|
|
||
| def test_stack(): | ||
|
|
||
| model = get_model() | ||
| tpc = mct.get_target_platform_capabilities('tensorflow', 'imx500') # only imx500 supported | ||
| q_model, _ = mct.ptq.keras_post_training_quantization(model, | ||
| get_representative_dataset(n_iter=1), | ||
| target_resource_utilization=None, | ||
| core_config=mct.core.CoreConfig(), | ||
| target_platform_capabilities=tpc) | ||
|
|
||
| assert getattr(q_model.layers[-2], "function") is tf.stack | ||
|
|
||
| stack_activation_holder = q_model.layers[-1] # activation holder for stack layer | ||
| assert isinstance(stack_activation_holder, KerasActivationQuantizationHolder) | ||
| assert stack_activation_holder.activation_holder_quantizer.num_bits == 8 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,63 @@ | ||
| # Copyright 2025 Sony Semiconductor Solutions, Inc. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ============================================================================== | ||
| from typing import Iterator, List | ||
| import torch | ||
| import torch.nn as nn | ||
| import model_compression_toolkit as mct | ||
| from mct_quantizers import PytorchActivationQuantizationHolder | ||
|
|
||
|
|
||
| def get_model(): | ||
|
|
||
| class StackModel(nn.Module): | ||
| def __init__(self): | ||
| super().__init__() | ||
| self.conv1 = nn.Sequential( | ||
| nn.Conv2d(3, 16, kernel_size=3, padding=1), | ||
| nn.ReLU() | ||
| ) | ||
| self.conv2 = nn.Sequential( | ||
| nn.Conv2d(3, 16, kernel_size=3, padding=1), | ||
| nn.ReLU() | ||
| ) | ||
|
|
||
| def forward(self, x): | ||
| out1 = self.conv1(x) | ||
| out2 = self.conv2(x) | ||
| output = torch.stack([out1, out2], dim=-1) | ||
| return output | ||
| return StackModel() | ||
|
|
||
|
|
||
| def get_representative_dataset(n_iter=1): | ||
| def representative_dataset() -> Iterator[List]: | ||
| for _ in range(n_iter): | ||
| yield [torch.randn(1, 3, 32, 32)] | ||
| return representative_dataset | ||
|
|
||
|
|
||
| def test_stack(): | ||
|
|
||
| model = get_model() | ||
| tpc = mct.get_target_platform_capabilities('pytorch', 'imx500') # only imx500 supported | ||
| q_model, _ = mct.ptq.pytorch_post_training_quantization(model, | ||
| get_representative_dataset(n_iter=1), | ||
| target_resource_utilization=None, | ||
| core_config=mct.core.CoreConfig(), | ||
| target_platform_capabilities=tpc) | ||
|
|
||
| assert hasattr(q_model, 'stack_activation_holder_quantizer') # activation holder for stack layer | ||
| assert isinstance(q_model.stack_activation_holder_quantizer, PytorchActivationQuantizationHolder) | ||
| assert q_model.stack_activation_holder_quantizer.activation_holder_quantizer.num_bits == 8 |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.