Commit f3085cd
feat: enable F32 output in CpuGemmConv2d
- Updated the convolution reference to branch the epilogue:
  * TO=float: int32-to-float dequant (acc * sA * sB + bias_f32)
  * TO!=float: usual quantize_down_scale_by_fixedpoint with int32 bias
- Changed the fixture to use an F32 bias tensor for Q->F32 runs (instead of S32), matching the arm_gemm dequant epilogue, which only supports float bias.
- Added explicit template instantiations of convolution_layer with TBias=float, TO=float to fix linker errors in validation.
- Disabled activation in the arm_gemm dequant path: offsets are applied afterwards by CpuGemmLowpOffsetContributionKernel, so activation must run there to see the correct final accumulator.

This aligns target and reference for quantized-to-F32 convolution tests and prevents premature clamping before offset contributions.

Change-Id: I6fffc98dc0798542a2702e6a593b850c16561e3b
Signed-off-by: Pablo Marquez Tello <[email protected]>
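For readers unfamiliar with the dequant epilogue named above, here is a minimal sketch of what the TO=float branch computes per output element. This is not the actual ACL reference code; `dequant_epilogue`, `acc`, `sA`, `sB` and `bias_f32` are illustrative names for the int32 accumulator, the input and weight quantization scales, and the float bias.

    #include <cstdint>

    // Sketch only: the TO == float branch of the reference epilogue.
    // acc is the int32 accumulator with zero-point (offset) contributions
    // already folded in; sA and sB are the input and weight scales.
    inline float dequant_epilogue(int32_t acc, float sA, float sB, float bias_f32)
    {
        return static_cast<float>(acc) * sA * sB + bias_f32;
    }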
1 parent 932f767 commit f3085cd

File tree

8 files changed: +205 -83 lines


src/cpu/operators/CpuGemmConv2d.cpp

Lines changed: 36 additions & 9 deletions
@@ -287,12 +287,29 @@ void CpuGemmConv2d::configure_mm(const ITensorInfo *src,
     }

     GEMMLowpOutputStageInfo output_info;
-    output_info.type                     = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
-    output_info.gemmlowp_offset          = uoqinfo.offset;
-    output_info.gemmlowp_min_bound       = min_activation;
-    output_info.gemmlowp_max_bound       = max_activation;
-    output_info.is_quantized_per_channel = (tmp_weights.data_type() == DataType::QSYMM8_PER_CHANNEL);
-    quantization::calculate_quantized_multipliers(iqinfo, wqinfo, oqinfo, output_info);
+
+    // F32 dequant path? (input quantized, output float)
+    const bool dequantize_f32 = (dst->data_type() == DataType::F32);
+
+    if (dequantize_f32)
+    {
+        // No requant stage; offsets are handled via offset-contribution on int32
+        output_info.type                     = GEMMLowpOutputStageType::NONE;
+        output_info.gemmlowp_offset          = 0;
+        output_info.gemmlowp_min_bound       = 0;
+        output_info.gemmlowp_max_bound       = 0;
+        output_info.is_quantized_per_channel = false; // irrelevant when NONE
+    }
+    else
+    {
+        // Existing Q->Q path
+        output_info.type                     = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
+        output_info.gemmlowp_offset          = uoqinfo.offset;
+        output_info.gemmlowp_min_bound       = min_activation;
+        output_info.gemmlowp_max_bound       = max_activation;
+        output_info.is_quantized_per_channel = (tmp_weights.data_type() == DataType::QSYMM8_PER_CHANNEL);
+        quantization::calculate_quantized_multipliers(iqinfo, wqinfo, oqinfo, output_info);
+    }

     const GEMMInfo gemm_info =
         GEMMInfo(false /* is_a_reshaped */, false /* is_b_reshaped */, true /* reshape_b_only_on_first_run */,
@@ -504,9 +521,11 @@ void CpuGemmConv2d::configure(const ITensorInfo *src,
     }

     const unsigned int mat_weights_cols = weights->dimension(idx_kernels);
+    const bool         dequantize_f32   = is_data_type_quantized(data_type) && dst->data_type() == DataType::F32;

     // Create temporary GEMM output tensor in case we cannot skip col2im
-    const DataType output_data_type = data_type == DataType::BFLOAT16 ? DataType::F32 : data_type;
+    const DataType output_data_type = data_type == DataType::BFLOAT16 || dequantize_f32 ? DataType::F32 : data_type;
+
     if (!_skip_col2im)
     {
         TensorShape shape_gemm;
@@ -725,7 +744,14 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src,
     {
         if (is_quantized)
         {
-            ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::S32);
+            if (data_type == DataType::QASYMM8_SIGNED && dst->data_type() == DataType::F32)
+            {
+                ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::F32);
+            }
+            else
+            {
+                ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::S32);
+            }
         }
         else if (is_bf16)
         {
@@ -776,8 +802,9 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src,
         gemm_input_to_use = &im2col_reshaped_info;
     }

+    const bool dequantize_f32 = is_data_type_quantized(data_type) && dst->data_type() == DataType::F32;
     // Create temporary GEMM output tensor in case we cannot skip col2im
-    const DataType output_data_type = data_type == DataType::BFLOAT16 ? DataType::F32 : data_type;
+    const DataType output_data_type = data_type == DataType::BFLOAT16 || dequantize_f32 ? DataType::F32 : data_type;
     if (!skip_col2im)
     {
         TensorShape shape_gemm = gemm_input_to_use->tensor_shape();
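The `GEMMLowpOutputStageType::NONE` branch above relies on the offset contributions being applied later, while the accumulator is still int32. As a self-contained sketch (illustrative names only; this is not the kernel's code), the identity that CpuGemmLowpOffsetContributionKernel implements on top of the raw dot product looks like this for a single output element, where K is the accumulation depth:

    #include <cstdint>
    #include <vector>

    // Sketch only: returns sum_k (qa[k] - zA) * (qb[k] - zB), i.e. the int32
    // accumulator the F32 dequant epilogue expects, built from the raw dot
    // product plus the zero-point (offset) correction terms.
    int32_t offset_contributed_acc(const std::vector<int8_t> &qa,
                                   const std::vector<int8_t> &qb,
                                   int32_t zA, int32_t zB)
    {
        const int32_t K = static_cast<int32_t>(qa.size());
        int32_t dot = 0, sum_a = 0, sum_b = 0;
        for (int32_t k = 0; k < K; ++k)
        {
            dot   += static_cast<int32_t>(qa[k]) * static_cast<int32_t>(qb[k]);
            sum_a += qa[k];
            sum_b += qb[k];
        }
        // Expanding (qa - zA)(qb - zB) and summing over k gives:
        return dot - zB * sum_a - zA * sum_b + K * zA * zB;
    }

This is also why the commit disables activation inside the arm_gemm dequant path: an activation applied before these correction terms would clamp the wrong accumulator.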

src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp

Lines changed: 1 addition & 1 deletion
@@ -200,7 +200,7 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
         case DataType::U8:
         case DataType::S8:
         {
-            if (is_data_type_quantized_asymmetric(a_to_use->data_type()) &&
+            if (dst->data_type() != DataType::F32 && is_data_type_quantized_asymmetric(a_to_use->data_type()) &&
                 info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
             {
                 auto c_info_to_use = c == nullptr ? nullptr : c;
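In effect, when the destination is F32 a quantized GEMM no longer takes the fused fixed-point requantization path here; it falls through to the plain int32-accumulator route, where the offset-contribution kernel runs and the conversion to float happens afterwards, as described in the commit message.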

tests/datasets/SmallConvolutionLayerDataset.h

Lines changed: 4 additions & 4 deletions
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2025 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -21,8 +21,8 @@
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
-#ifndef ARM_COMPUTE_TEST_SMALL_CONVOLUTION_LAYER_DATASET
-#define ARM_COMPUTE_TEST_SMALL_CONVOLUTION_LAYER_DATASET
+#ifndef ACL_TESTS_DATASETS_SMALLCONVOLUTIONLAYERDATASET_H
+#define ACL_TESTS_DATASETS_SMALLCONVOLUTIONLAYERDATASET_H

 #include "tests/datasets/ConvolutionLayerDataset.h"

@@ -246,4 +246,4 @@ class SmallGroupedConvolutionLayerDataset final : public ConvolutionLayerDataset
 } // namespace datasets
 } // namespace test
 } // namespace arm_compute
-#endif /* ARM_COMPUTE_TEST_SMALL_CONVOLUTION_LAYER_DATASET */
+#endif // ACL_TESTS_DATASETS_SMALLCONVOLUTIONLAYERDATASET_H

tests/validation/NEON/ConvolutionLayer.cpp

Lines changed: 20 additions & 0 deletions
@@ -1363,6 +1363,10 @@ template <typename T>
 using NEGEMMConvolutionLayerForUpdatedStaticQuantInfoAfterConfigureFixture = ConvolutionValidationForUpdatedStaticQuantInfoAfterConfigureFixture<Tensor, Accessor, NEGEMMConvolutionLayer, T>;
 template <typename T>
 using NEGEMMConvolutionLayerQuantizedFixture = ConvolutionValidationQuantizedFixture<Tensor, Accessor, NEConvolutionLayer, T>;
+template <typename T>
+using NEGEMMConvolutionLayerQuantizedF32OutputFixture = ConvolutionValidationQuantizedFixture<Tensor, Accessor, NEConvolutionLayer, T, false, float>;
+
+
 template <typename T>
 using NEGEMMConvolutionLayerQuantizedMixedDataLayoutFixture = ConvolutionValidationQuantizedFixture<Tensor, Accessor, NEConvolutionLayer, T, true>;

@@ -1397,6 +1401,21 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerForUpdatedStaticQuantInfo
     // Validate output
     validate(Accessor(_target), _reference, tolerance_qasymm8);
 }
+
+FIXTURE_DATA_TEST_CASE(RunSmallDequantizeF32, NEGEMMConvolutionLayerQuantizedF32OutputFixture<int8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
+                       framework::dataset::make("ReshapeWeights", { true })),
+                       framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
+                       framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
+                       framework::dataset::make("QuantizationInfoIfActivationEnabled", { QuantizationInfo(2.f / 255.f, 10) })),
+                       framework::dataset::make("ActivationInfo", { ActivationLayerInfo() })))
+{
+    // Validate output
+    validate(Accessor(_target), _reference, rel_tolerance_f32, 0.f, float(abs_tolerance_f32));
+}
+
+
+
+
 TEST_SUITE_END() // QASYMM8_SIGNED

 TEST_SUITE(QASYMM8)
@@ -1425,6 +1444,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerQuantizedFixture<uint8_t>
     // Validate output
     validate(Accessor(_target), _reference, tolerance_qasymm8);
 }
+
 FIXTURE_DATA_TEST_CASE(RunMixedDataLayout, NEGEMMConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::ALL,
                        combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
                            framework::dataset::make("Input", TensorShape(23U, 27U, 5U)),

tests/validation/fixtures/ConvolutionLayerFixture.h

Lines changed: 40 additions & 18 deletions
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2024 Arm Limited.
+ * Copyright (c) 2017-2025 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -46,7 +46,7 @@
 #include "tests/validation/reference/PadLayer.h"
 #include "tests/validation/reference/Permute.h"
 #include "tests/validation/reference/Utils.h"
-
+#include "tests/validation/reference/DequantizationLayer.h"
 #include <random>
 #include <type_traits>

@@ -85,13 +85,28 @@ configure_conv_function(ConvolutionFunction &func,
 #endif // ARM_COMPUTE_OPENCL_ENABLED
 } // namespace detail

-template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TW>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TW, typename TO = T>
 class ConvolutionValidationGenericFixture : public framework::Fixture
 {
 public:
-    using TBias = typename std::conditional < std::is_same<typename std::decay<T>::type, uint8_t>::value
-                                              || std::is_same<typename std::decay<T>::type, int8_t>::value,
-                                              int32_t, T >::type;
+    // Quantized input?
+    static constexpr bool T_is_q =
+        std::is_same<typename std::decay<T>::type, uint8_t>::value ||
+        std::is_same<typename std::decay<T>::type, int8_t>::value;
+
+    // Float output?
+    static constexpr bool TO_is_f32 =
+        std::is_same<typename std::decay<TO>::type, float>::value;
+
+    // Bias type:
+    //   - Q->F32: float
+    //   - Q->Q  : int32_t
+    //   - FP->* : T
+    using TBias = typename std::conditional<
+        (T_is_q && TO_is_f32),
+        float,
+        typename std::conditional<T_is_q, int32_t, T>::type
+    >::type;

     void setup_quantization(TensorShape input_shape, TensorShape weights_shape, QuantizationInfo &input_q_info,
                             QuantizationInfo &weights_q_info, DataType data_type)
@@ -144,14 +159,21 @@ class ConvolutionValidationGenericFixture : public framework::Fixture
         _data_type         = data_type;
         _weights_data_type = weights_data_type;
         const bool is_quantized = is_data_type_quantized(weights_data_type);
-        _is_bfloat16      = data_type == DataType::BFLOAT16;
-        _bias_data_type   = is_quantized ? DataType::S32 : (_is_bfloat16 ? DataType::F32 : data_type);
-        _output_data_type = _is_bfloat16 ? DataType::F32 : data_type;
+
+        _is_bfloat16      = data_type == DataType::BFLOAT16;
+        _output_data_type = (_is_bfloat16 || std::is_same<TO, float>::value) ? DataType::F32 : data_type;
+
+        const bool q_to_f32 = is_quantized && (_output_data_type == DataType::F32);
+        _bias_data_type     = q_to_f32 ? DataType::F32
+                                       : (is_quantized ? DataType::S32
+                                                       : (_is_bfloat16 ? DataType::F32 : data_type));
+
         _quantization_info        = quantization_info;
         _weight_quantization_info = weight_quantization_info;
         _data_layout              = data_layout;
         _dst_q_info               = quantization_info;

+
         if(is_quantized && !is_data_type_quantized_symmetric(weights_data_type) && (!act_info.enabled() || act_info.activation() == ActivationFunction::IDENTITY))
         {
             setup_quantization(input_shape, weights_shape, _quantization_info, _weight_quantization_info, data_type);
@@ -503,11 +525,10 @@ class ConvolutionValidationGenericFixture : public framework::Fixture
             // Compute Convolution function
             conv.run();
         }
-
         return dst;
     }

-    SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
+    SimpleTensor<TO> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
                                        const Size2D &dilation, const ActivationLayerInfo act_info, PaddingList pre_pad_layer = PaddingList({}))
     {
         ARM_COMPUTE_ERROR_ON((input_shape[2] % weights_shape[2]) != 0);
@@ -534,19 +555,20 @@ class ConvolutionValidationGenericFixture : public framework::Fixture
             regularize_values(static_cast<void *>(src.data()), src.num_elements());
             regularize_values(static_cast<void *>(weights.data()), weights.num_elements());
         }
-
         if(pre_pad_layer.size() > 0)
         {
             src = reference::pad_layer<T>(src, pre_pad_layer, PixelValue(0), PaddingMode::CONSTANT);
         }

-        return (act_info.enabled()) ? reference::activation_layer<T>(reference::convolution_layer<T>(src, weights, bias, output_shape, info, dilation, num_groups, _dst_q_info),
+        auto res = (act_info.enabled()) ? reference::activation_layer<TO>(reference::convolution_layer<T, TW, TBias, TO>(src, weights, bias, output_shape, info, dilation, num_groups, _dst_q_info),
                                       act_info) :
-                                      reference::convolution_layer<T>(src, weights, bias, output_shape, info, dilation, num_groups, _dst_q_info);
+                                      reference::convolution_layer<T, TW, TBias, TO>(src, weights, bias, output_shape, info, dilation, num_groups, _dst_q_info);
+
+        return res;
     }

     TensorType       _target{};
-    SimpleTensor<T>  _reference{};
+    SimpleTensor<TO> _reference{};
     DataType         _data_type{};
     DataType         _weights_data_type{};
     DataType         _bias_data_type{};
@@ -602,14 +624,14 @@ class ConvolutionValidationWithPaddingFixture : public ConvolutionValidationGene
     }
 };

-template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false>
-class ConvolutionValidationQuantizedFixture : public ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false, typename TO = T>
+class ConvolutionValidationQuantizedFixture : public ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T, TO>
 {
 public:
     void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, bool reshape_weights, DataType data_type,
                DataLayout data_layout, QuantizationInfo quantization_info, ActivationLayerInfo act_info)
     {
-        ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights,
+        ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T, TO>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights,
                                                                                                  data_type, data_type, data_layout, quantization_info, quantization_info, act_info, mixed_layout);
     }
 };
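The new TBias alias is the crux of the fixture change. As a quick compile-time sanity check, the following self-contained sketch (outside the ACL tree; the free-standing `TBias` alias template here is an assumption that mirrors the fixture's nested std::conditional) confirms it resolves to float for Q->F32, int32_t for Q->Q, and T otherwise:

    #include <cstdint>
    #include <type_traits>

    // Mirror of the fixture's bias-type selection, lifted out of the class
    // for illustration.
    template <typename T, typename TO>
    using TBias = typename std::conditional<
        (std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) &&
            std::is_same<TO, float>::value,
        float,
        typename std::conditional<
            std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value,
            int32_t, T>::type>::type;

    static_assert(std::is_same<TBias<int8_t, float>, float>::value,    "Q->F32 uses float bias");
    static_assert(std::is_same<TBias<int8_t, int8_t>, int32_t>::value, "Q->Q uses int32 bias");
    static_assert(std::is_same<TBias<float, float>, float>::value,     "FP path keeps T");

This matches the validate() change in CpuGemmConv2d.cpp, which now requires an F32 bias for the QASYMM8_SIGNED-to-F32 path and S32 for every other quantized path.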
