diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc.py index 87a9a48f3..d2392cd79 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc.py @@ -166,10 +166,11 @@ def generate_tpc(default_config: OpQuantizationConfig, operator_set = [] fusing_patterns = [] + activation_quantization_config = (default_configuration_options.clone_and_edit_weight_attribute(enable_weights_quantization=False)) + operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.STACK, qc_options=activation_quantization_config)) + no_quantization_config = (default_configuration_options.clone_and_edit(enable_activation_quantization=False) .clone_and_edit_weight_attribute(enable_weights_quantization=False)) - - operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.STACK, qc_options=no_quantization_config)) operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.UNSTACK, qc_options=no_quantization_config)) operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.DROPOUT, qc_options=no_quantization_config)) operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.FLATTEN, qc_options=no_quantization_config)) diff --git a/tests_pytest/common_tests/unit_tests/target_platform_capabilities/test_tpc_stack.py b/tests_pytest/common_tests/unit_tests/target_platform_capabilities/test_tpc_stack.py new file mode 100644 index 000000000..4c297cd5d --- /dev/null +++ b/tests_pytest/common_tests/unit_tests/target_platform_capabilities/test_tpc_stack.py @@ -0,0 +1,29 @@ +# Copyright 2025 Sony Semiconductor Solutions, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc import get_tpc as get_tpc_imx500_v1
+
+
+def test_tpc_stack():
+    """Check that the imx500 v1 TPC quantizes Stack activations (8 bits) but not its weights."""
+    tpc = get_tpc_imx500_v1()  # only imx500 supported
+    stack_opsets = [opset for opset in tpc.operator_set if opset.name == 'Stack']
+    assert stack_opsets
+    # Flatten every quantization option of every Stack operator set into one stream.
+    stack_qcs = (qc for opset in stack_opsets for qc in opset.qc_options.quantization_configurations)
+    for qc in stack_qcs:
+        assert qc.default_weight_attr_config.enable_weights_quantization == False
+        assert qc.attr_weights_configs_mapping == {}
+        assert qc.enable_activation_quantization == True
+        assert qc.activation_n_bits == 8
diff --git a/tests_pytest/keras_tests/e2e_tests/test_stack.py b/tests_pytest/keras_tests/e2e_tests/test_stack.py
new file mode 100644
index 000000000..c40ae4ac6
--- /dev/null
+++ b/tests_pytest/keras_tests/e2e_tests/test_stack.py
@@ -0,0 +1,55 @@
+# Copyright 2025 Sony Semiconductor Solutions, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from typing import Iterator, List
+import numpy as np
+import tensorflow as tf
+from tensorflow.keras import layers
+import keras
+import model_compression_toolkit as mct
+from mct_quantizers import KerasActivationQuantizationHolder
+
+
+def get_model():
+    """Build a model with two parallel Conv2D (relu) branches combined by `tf.stack` along axis -1."""
+    net_input = keras.layers.Input((32, 32, 3))
+    branches = [layers.Conv2D(16, kernel_size=3, padding='same', activation='relu')(net_input)
+                for _ in range(2)]
+    stacked = tf.stack(branches, -1)
+    return keras.Model(net_input, stacked)
+
+
+def get_representative_dataset(n_iter=1):
+    """Return a callable producing `n_iter` single-element batches of random (1, 32, 32, 3) data."""
+    def representative_dataset() -> Iterator[List]:
+        # Lazily emit one freshly sampled batch per iteration.
+        yield from ([np.random.randn(1, 32, 32, 3)] for _ in range(n_iter))
+    return representative_dataset
+
+
+def test_stack():
+    """Run Keras PTQ on the stack model and check the stack output gets an 8-bit activation holder."""
+    float_model = get_model()
+    imx500_tpc = mct.get_target_platform_capabilities('tensorflow', 'imx500')  # only imx500 supported
+    q_model, _ = mct.ptq.keras_post_training_quantization(
+        float_model,
+        get_representative_dataset(n_iter=1),
+        target_resource_utilization=None,
+        core_config=mct.core.CoreConfig(),
+        target_platform_capabilities=imx500_tpc)
+
+    stack_layer, quant_holder = q_model.layers[-2], q_model.layers[-1]  # stack op, then its activation holder
+    assert stack_layer.function is tf.stack
+    assert isinstance(quant_holder, KerasActivationQuantizationHolder)
+    assert quant_holder.activation_holder_quantizer.num_bits == 8
diff --git a/tests_pytest/pytorch_tests/e2e_tests/test_stack.py b/tests_pytest/pytorch_tests/e2e_tests/test_stack.py
new file mode 100644
index 000000000..43078666d
--- /dev/null
+++ b/tests_pytest/pytorch_tests/e2e_tests/test_stack.py
@@ -0,0 +1,64 @@
+# Copyright 2025 Sony Semiconductor Solutions, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from typing import Iterator, List
+import torch
+import torch.nn as nn
+import model_compression_toolkit as mct
+from mct_quantizers import PytorchActivationQuantizationHolder
+
+
+def get_model():
+    """Build the float model under test.
+
+    Two parallel Conv2d+ReLU branches read the same input; their outputs are stacked along dim -1.
+    """
+    def conv_relu():
+        # One branch: 3 -> 16 channel 3x3 convolution (padding 1) followed by ReLU.
+        return nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU())
+
+    class StackModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = conv_relu()
+            self.conv2 = conv_relu()
+
+        def forward(self, x):
+            """Feed `x` through both branches and stack the two results with `torch.stack(..., dim=-1)`."""
+            branch_outputs = [self.conv1(x), self.conv2(x)]
+            return torch.stack(branch_outputs, dim=-1)
+    return StackModel()
+
+
+def get_representative_dataset(n_iter=1):
+    """Return a callable producing `n_iter` single-element batches of random (1, 3, 32, 32) tensors."""
+    def representative_dataset() -> Iterator[List]:
+        # Lazily emit one freshly sampled batch per iteration.
+        yield from ([torch.randn(1, 3, 32, 32)] for _ in range(n_iter))
+    return representative_dataset
+
+
+def test_stack():
+    """Run PyTorch PTQ on the stack model and check the stack output gets an 8-bit activation holder."""
+    float_model = get_model()
+    imx500_tpc = mct.get_target_platform_capabilities('pytorch', 'imx500')  # only imx500 supported
+    q_model, _ = mct.ptq.pytorch_post_training_quantization(
+        float_model, get_representative_dataset(n_iter=1),
+        target_resource_utilization=None,
+        core_config=mct.core.CoreConfig(),
+        target_platform_capabilities=imx500_tpc)
+    assert hasattr(q_model, 'stack_activation_holder_quantizer')  # activation holder for stack layer
+    quant_holder = q_model.stack_activation_holder_quantizer
+    assert isinstance(quant_holder, PytorchActivationQuantizationHolder)
+    assert quant_holder.activation_holder_quantizer.num_bits == 8