1
1
/*
2
- * Copyright (c) 2017-2024 Arm Limited.
2
+ * Copyright (c) 2017-2025 Arm Limited.
3
3
*
4
4
* SPDX-License-Identifier: MIT
5
5
*
46
46
#include " tests/validation/reference/PadLayer.h"
47
47
#include " tests/validation/reference/Permute.h"
48
48
#include " tests/validation/reference/Utils.h"
49
-
49
+ # include " tests/validation/reference/DequantizationLayer.h "
50
50
#include < random>
51
51
#include < type_traits>
52
52
@@ -85,13 +85,28 @@ configure_conv_function(ConvolutionFunction &func,
85
85
#endif // ARM_COMPUTE_OPENCL_ENABLED
86
86
} // namespace detail
87
87
88
- template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TW>
88
+ template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TW, typename TO=T >
89
89
class ConvolutionValidationGenericFixture : public framework ::Fixture
90
90
{
91
91
public:
92
- using TBias = typename std::conditional < std::is_same<typename std::decay<T>::type, uint8_t >::value
93
- || std::is_same<typename std::decay<T>::type, int8_t >::value,
94
- int32_t , T >::type;
92
+ // Quantized input?
93
+ static constexpr bool T_is_q =
94
+ std::is_same<typename std::decay<T>::type, uint8_t >::value ||
95
+ std::is_same<typename std::decay<T>::type, int8_t >::value;
96
+
97
+ // Float output?
98
+ static constexpr bool TO_is_f32 =
99
+ std::is_same<typename std::decay<TO>::type, float >::value;
100
+
101
+ // Bias type:
102
+ // - Q->F32: float
103
+ // - Q->Q : int32_t
104
+ // - FP->* : T
105
+ using TBias = typename std::conditional<
106
+ (T_is_q && TO_is_f32),
107
+ float ,
108
+ typename std::conditional<T_is_q, int32_t , T>::type
109
+ >::type;
95
110
96
111
void setup_quantization (TensorShape input_shape, TensorShape weights_shape, QuantizationInfo &input_q_info,
97
112
QuantizationInfo &weights_q_info, DataType data_type)
@@ -144,14 +159,21 @@ class ConvolutionValidationGenericFixture : public framework::Fixture
144
159
_data_type = data_type;
145
160
_weights_data_type = weights_data_type;
146
161
const bool is_quantized = is_data_type_quantized (weights_data_type);
147
- _is_bfloat16 = data_type == DataType::BFLOAT16;
148
- _bias_data_type = is_quantized ? DataType::S32 : (_is_bfloat16 ? DataType::F32 : data_type);
149
- _output_data_type = _is_bfloat16 ? DataType::F32 : data_type;
162
+
163
+ _is_bfloat16 = data_type == DataType::BFLOAT16;
164
+ _output_data_type = (_is_bfloat16 || std::is_same<TO, float >::value) ? DataType::F32 : data_type;
165
+
166
+ const bool q_to_f32 = is_quantized && (_output_data_type == DataType::F32);
167
+ _bias_data_type = q_to_f32 ? DataType::F32
168
+ : (is_quantized ? DataType::S32
169
+ : (_is_bfloat16 ? DataType::F32 : data_type));
170
+
150
171
_quantization_info = quantization_info;
151
172
_weight_quantization_info = weight_quantization_info;
152
173
_data_layout = data_layout;
153
174
_dst_q_info = quantization_info;
154
175
176
+
155
177
if (is_quantized && !is_data_type_quantized_symmetric (weights_data_type) && (!act_info.enabled () || act_info.activation () == ActivationFunction::IDENTITY))
156
178
{
157
179
setup_quantization (input_shape, weights_shape, _quantization_info, _weight_quantization_info, data_type);
@@ -503,11 +525,10 @@ class ConvolutionValidationGenericFixture : public framework::Fixture
503
525
// Compute Convolution function
504
526
conv.run ();
505
527
}
506
-
507
528
return dst;
508
529
}
509
530
510
- SimpleTensor<T > compute_reference (const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
531
+ SimpleTensor<TO > compute_reference (const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
511
532
const Size2D &dilation, const ActivationLayerInfo act_info, PaddingList pre_pad_layer = PaddingList({}))
512
533
{
513
534
ARM_COMPUTE_ERROR_ON ((input_shape[2 ] % weights_shape[2 ]) != 0 );
@@ -534,19 +555,20 @@ class ConvolutionValidationGenericFixture : public framework::Fixture
534
555
regularize_values (static_cast <void *>(src.data ()), src.num_elements ());
535
556
regularize_values (static_cast <void *>(weights.data ()), weights.num_elements ());
536
557
}
537
-
538
558
if (pre_pad_layer.size () > 0 )
539
559
{
540
560
src = reference::pad_layer<T>(src, pre_pad_layer, PixelValue (0 ), PaddingMode::CONSTANT);
541
561
}
542
562
543
- return (act_info.enabled ()) ? reference::activation_layer<T >(reference::convolution_layer<T>(src, weights, bias, output_shape, info, dilation, num_groups, _dst_q_info),
563
+ auto res= (act_info.enabled ()) ? reference::activation_layer<TO >(reference::convolution_layer<T,TW,TBias,TO >(src, weights, bias, output_shape, info, dilation, num_groups, _dst_q_info),
544
564
act_info) :
545
- reference::convolution_layer<T>(src, weights, bias, output_shape, info, dilation, num_groups, _dst_q_info);
565
+ reference::convolution_layer<T,TW,TBias,TO>(src, weights, bias, output_shape, info, dilation, num_groups, _dst_q_info);
566
+
567
+ return res;
546
568
}
547
569
548
570
TensorType _target{};
549
- SimpleTensor<T > _reference{};
571
+ SimpleTensor<TO > _reference{};
550
572
DataType _data_type{};
551
573
DataType _weights_data_type{};
552
574
DataType _bias_data_type{};
@@ -602,14 +624,14 @@ class ConvolutionValidationWithPaddingFixture : public ConvolutionValidationGene
602
624
}
603
625
};
604
626
605
- template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false >
606
- class ConvolutionValidationQuantizedFixture : public ConvolutionValidationGenericFixture <TensorType, AccessorType, FunctionType, T, T>
627
+ template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false , typename TO = T >
628
+ class ConvolutionValidationQuantizedFixture : public ConvolutionValidationGenericFixture <TensorType, AccessorType, FunctionType, T, T, TO >
607
629
{
608
630
public:
609
631
void setup (TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, bool reshape_weights, DataType data_type,
610
632
DataLayout data_layout, QuantizationInfo quantization_info, ActivationLayerInfo act_info)
611
633
{
612
- ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T>::setup (input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights,
634
+ ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T, TO >::setup (input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights,
613
635
data_type, data_type, data_layout, quantization_info, quantization_info, act_info, mixed_layout);
614
636
}
615
637
};
0 commit comments