Skip to content

Commit d6a8cba

Browse files
YUNQIUGUO and rachguo authored
[NNAPI QDQ] Add nnapi qdq softmax op support (#10591)
* wip * save * update pr comments * update Co-authored-by: rachguo <[email protected]>
1 parent 4d3cd2f commit d6a8cba

File tree

7 files changed

+152
-24
lines changed

7 files changed

+152
-24
lines changed

onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() { retur
3434
{"Resize", {}}}; }
3535

3636
// Op-type -> supported-ONNX-versions map for the unary QDQ selector.
static const OpVersionsAndSelector::OpVersionsMap GetUnaryOpVersionsMap() {
  return {{"AveragePool", {}},
          {"Softmax", {}},
          {"LeakyRelu", {}}};
}
3839
// Op-type -> supported-ONNX-versions map for the binary QDQ selector.
static const OpVersionsAndSelector::OpVersionsMap GetBinaryOpVersionsMap() {
  return {{"Add", {}},
          {"Mul", {}}};
}

onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ QuantizedOpType GetQuantizedOpType(const NodeUnit& node_unit) {
8282
return QuantizedOpType::QDQTranspose;
8383
else if (op_type == "Reshape")
8484
return QuantizedOpType::QDQReshape;
85+
else if (op_type == "Softmax")
86+
return QuantizedOpType::QDQSoftmax;
8587
} else {
8688
// throw?
8789
}

onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ enum class QuantizedOpType : uint8_t {
9292
QDQMul,
9393
QDQTranspose,
9494
QDQReshape,
95+
QDQSoftmax,
9596
// TODO, add other QDQ NodeUnit types
9697
};
9798

onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1478,10 +1478,25 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
14781478
#pragma region op_softmax
14791479

14801480
class SoftMaxOpBuilder : public BaseOpBuilder {
1481+
public:
1482+
void AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const override;
1483+
14811484
private:
14821485
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override;
1486+
bool IsQuantizedOp(const NodeUnit& node_unit) const override;
14831487
};
14841488

1489+
bool SoftMaxOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const {
1490+
return GetQuantizedOpType(node_unit) == QuantizedOpType::QDQSoftmax;
1491+
}
1492+
1493+
void SoftMaxOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const {
  // Only the QDQ form carries scale/zero-point initializers to skip.
  if (!IsQuantizedOp(node_unit))
    return;

  // x_scale, x_zp
  AddQuantizationScaleAndZeroPointToSkip(model_builder, *node_unit.Inputs()[0].quant_param);
  // y_scale, y_zp
  AddQuantizationScaleAndZeroPointToSkip(model_builder, *node_unit.Outputs()[0].quant_param);
}
1499+
14851500
Status SoftMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const {
14861501
auto& shaper(model_builder.GetShaper());
14871502
const auto& operand_indices(model_builder.GetOperandIndices());
@@ -1499,6 +1514,21 @@ Status SoftMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons
14991514

15001515
int32_t axis = helper.Get("axis", 1);
15011516

1517+
// Check if the quantization scale and ZP are correct
1518+
float x_scale = 0.0f;
1519+
int32_t x_zero_point = 0;
1520+
float y_scale = 0.0f;
1521+
int32_t y_zero_point = 0;
1522+
if (IsQuantizedOp(node_unit)) {
1523+
ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint(
1524+
model_builder.GetInitializerTensors(), node_unit.Inputs()[0], node_unit.ModelPath(),
1525+
x_scale, x_zero_point));
1526+
1527+
ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input, x_scale, x_zero_point));
1528+
1529+
y_scale = 1.f / 256;
1530+
}
1531+
15021532
const auto& output = node_unit.Outputs()[0].node_arg.Name();
15031533
float beta = 1.f;
15041534
std::vector<uint32_t> input_indices;
@@ -1511,7 +1541,7 @@ Status SoftMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons
15111541
}
15121542

15131543
ORT_RETURN_IF_ERROR(shaper.Identity(input, output));
1514-
const OperandType output_operand_type(operand_types.at(input).type, shaper[output]);
1544+
const OperandType output_operand_type(operand_types.at(input).type, shaper[output], y_scale, y_zero_point);
15151545
ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_SOFTMAX, input_indices,
15161546
{output}, {output_operand_type}));
15171547
return Status::OK();

onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc

Lines changed: 77 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,42 @@ static bool IsQuantizedIOSupported(const InitializedTensorSet& initializers, con
272272
return true;
273273
}
274274

275+
// Some Quantized NNAPI operations have required output scale and zero point
276+
// e.g. Softmax (uint8) requires output scale be 1.f/256 and zp be 0
277+
// This helper function checks if the given io_def has required scale and zp
278+
static bool HasRequiredScaleAndZeroPoint(const InitializedTensorSet& initializers,
279+
const std::string& op_desc,
280+
const NodeUnitIODef& io_def,
281+
const Path& path,
282+
float required_scale, int32_t required_zp) {
283+
float scale = 0.0f;
284+
int32_t zp = 0;
285+
auto status = GetQuantizationScaleAndZeroPoint(initializers, io_def, path,
286+
scale, zp);
287+
if (!status.IsOK()) {
288+
LOGS_DEFAULT(ERROR) << op_desc
289+
<< " GetQuantizationScaleAndZeroPoint failed, message: "
290+
<< status.ErrorMessage();
291+
return false;
292+
}
293+
294+
if (scale != required_scale) {
295+
LOGS_DEFAULT(VERBOSE) << op_desc
296+
<< " scale can only be [" << required_scale
297+
<< "], actual scale: " << scale;
298+
return false;
299+
}
300+
301+
if (zp != required_zp) {
302+
LOGS_DEFAULT(VERBOSE) << op_desc
303+
<< "] zero point can only be [" << required_zp
304+
<< "], actual zero point: " << scale;
305+
return false;
306+
}
307+
308+
return true;
309+
}
310+
275311
#pragma endregion helpers
276312

277313
#pragma region op_base
@@ -1142,8 +1178,19 @@ class SoftMaxOpSupportChecker : public BaseOpSupportChecker {
11421178
const OpSupportCheckParams& /* params */) const override {
11431179
return ANEURALNETWORKS_FEATURE_LEVEL_2;
11441180
}
1181+
bool HasSupportedInputOutputsImpl(
1182+
const InitializedTensorSet& initializers, const NodeUnit& node_unit,
1183+
const OpSupportCheckParams& params) const override;
1184+
1185+
bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override { return true; }
1186+
1187+
bool IsQuantizedOp(const NodeUnit& node_unit) const override;
11451188
};
11461189

1190+
bool SoftMaxOpSupportChecker::IsQuantizedOp(const NodeUnit& node_unit) const {
1191+
return GetQuantizedOpType(node_unit) == QuantizedOpType::QDQSoftmax;
1192+
}
1193+
11471194
bool SoftMaxOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit,
11481195
const OpSupportCheckParams& params) const {
11491196
Shape input_shape;
@@ -1171,6 +1218,32 @@ bool SoftMaxOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* i
11711218
return true;
11721219
}
11731220

1221+
bool SoftMaxOpSupportChecker::HasSupportedInputOutputsImpl(
1222+
const InitializedTensorSet& initializers, const NodeUnit& node_unit,
1223+
const OpSupportCheckParams& params) const {
1224+
if (!IsQuantizedOp(node_unit)) {
1225+
return BaseOpSupportChecker::HasSupportedInputOutputsImpl(initializers, node_unit, params);
1226+
}
1227+
1228+
if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, IOKind::Input)) {
1229+
return false;
1230+
}
1231+
1232+
if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, IOKind::Output)) {
1233+
return false;
1234+
}
1235+
1236+
// NNAPI requires the scale be 1.f/256 and zero point to be 0
1237+
if (!HasRequiredScaleAndZeroPoint(initializers,
1238+
MakeString("Op [", node_unit.OpType(), "] name [", node_unit.Name(), "]'s output 0 "),
1239+
node_unit.Outputs()[0], node_unit.ModelPath(),
1240+
1.f / 256 /* required_scale */, 0 /* required_zp */)) {
1241+
return false;
1242+
}
1243+
1244+
return true;
1245+
}
1246+
11741247
#pragma endregion
11751248

11761249
#pragma region op_gemm
@@ -1443,29 +1516,13 @@ int UnaryOpSupportChecker::GetMinSupportedOpSet(const NodeUnit& node_unit) const
14431516
const InitializedTensorSet& initializers, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) {
14441517
const auto& op_type = node_unit.OpType();
14451518
ORT_ENFORCE(op_type == "QLinearSigmoid");
1446-
const auto& op_name = node_unit.Name();
14471519

14481520
// NNAPI requires the scale be 1.f/256 and zero point to be 0
14491521
// See https://android.googlesource.com/platform/frameworks/ml/+/refs/heads/android10-c2f2-release/nn/common/operations/Activation.cpp#180
1450-
float output_scale = 0.0f;
1451-
int32_t output_zp = 0;
1452-
auto status = GetQuantizationScaleAndZeroPoint(initializers, node_unit.Outputs()[0], node_unit.ModelPath(),
1453-
output_scale, output_zp);
1454-
if (!status.IsOK()) {
1455-
LOGS_DEFAULT(ERROR) << "Op [" << op_type << "] name [" << op_name
1456-
<< "] GetQuantizationScaleAndZeroPoint failed, message: " << status.ErrorMessage();
1457-
return false;
1458-
}
1459-
1460-
if (output_scale != 1.f / 256) {
1461-
LOGS_DEFAULT(VERBOSE) << "Op [" << op_type << "] name [" << op_name
1462-
<< "] output scale can only be 1.f/256, actual scale: " << output_scale;
1463-
return false;
1464-
}
1465-
1466-
if (output_zp != 0) {
1467-
LOGS_DEFAULT(VERBOSE) << "Op [" << op_type << "] name [" << op_name
1468-
<< "] output zero point can only be 0, actual zero point: " << output_scale;
1522+
if (!HasRequiredScaleAndZeroPoint(initializers,
1523+
MakeString("Op [", op_type, "] name [", node_unit.Name(), "]'s output 0 "),
1524+
node_unit.Outputs()[0], node_unit.ModelPath(),
1525+
1.f / 256 /* required_scale */, 0 /* required_zp */)) {
14691526
return false;
14701527
}
14711528

onnxruntime/test/optimizer/qdq_test_utils.h

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,9 @@ GetQDQTestCaseFn BuildQDQTransposeTestCase(
193193
const std::vector<int64_t>& input_shape,
194194
const std::vector<int64_t>& perms) {
195195
return [input_shape, perms](ModelTestBuilder& builder) {
196-
auto* input_arg = builder.MakeInput<InputType>(input_shape, -128, 127);
196+
auto* input_arg = builder.MakeInput<InputType>(input_shape,
197+
std::numeric_limits<InputType>::min(),
198+
std::numeric_limits<InputType>::max());
197199
auto* output_arg = builder.MakeOutput();
198200

199201
InputType dq_zp = std::numeric_limits<InputType>::max() / 2;
@@ -215,5 +217,30 @@ GetQDQTestCaseFn BuildQDQTransposeTestCase(
215217

216218
GetQDQTestCaseFn BuildQDQReshapeTestCase(const std::vector<int64_t>& input_shape,
217219
const std::vector<int64_t>& reshape_shape);
220+
221+
template <typename InputType, typename OutputType>
222+
GetQDQTestCaseFn BuildQDQSoftMaxTestCase(const std::vector<int64_t>& input_shape, const int64_t& axis = -1) {
223+
return [input_shape, axis](ModelTestBuilder& builder) {
224+
auto* input_arg = builder.MakeInput<InputType>(input_shape,
225+
std::numeric_limits<InputType>::min(),
226+
std::numeric_limits<InputType>::max());
227+
228+
auto* output_arg = builder.MakeOutput();
229+
230+
// add DQ
231+
auto* dq_output = builder.MakeIntermediate();
232+
builder.AddDequantizeLinearNode<InputType>(input_arg, .003f, std::numeric_limits<InputType>::max() / 2, dq_output);
233+
234+
// add SoftMax
235+
auto* softmax_output = builder.MakeIntermediate();
236+
Node& softmax_node = builder.AddNode("Softmax", {dq_output}, {softmax_output});
237+
238+
softmax_node.AddAttribute("axis", axis);
239+
240+
// add Q
241+
builder.AddQuantizeLinearNode<OutputType>(softmax_output, 1.f / 256, 0, output_arg);
242+
};
243+
}
244+
218245
} // namespace test
219246
} // namespace onnxruntime

onnxruntime/test/providers/nnapi/nnapi_basic_test.cc

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,8 +306,8 @@ TEST(NnapiExecutionProviderTest, TestQDQConv) {
306306
uint8_t /* WeightType */,
307307
int32_t /* BiasType */,
308308
uint8_t /* OutputType */>(
309-
{1, 1, 5, 5} /*input_shape*/,
310-
{1, 1, 3, 3} /*weights_shape*/),
309+
{1, 1, 5, 5} /* input_shape */,
310+
{1, 1, 3, 3} /* weights_shape */),
311311
"nnapi_qdq_test_graph_conv",
312312
{true /* verify_entire_graph_use_ep */});
313313
}
@@ -384,6 +384,16 @@ TEST(NnapiExecutionProviderTest, TestQDQReshape) {
384384
});
385385
}
386386

387+
TEST(NnapiExecutionProviderTest, TestQDQSoftMax) {
  // QDQ Softmax over a [1, 32] input along axis 1; the entire graph is
  // expected to be assigned to the NNAPI EP.
  const auto build_test_case = BuildQDQSoftMaxTestCase<uint8_t, uint8_t>(
      {1, 32} /* input_shape */,
      static_cast<int64_t>(1) /* axis */);
  RunQDQModelTest(build_test_case,
                  "nnapi_qdq_test_graph_softmax",
                  {true /* verify_entire_graph_use_ep */});
}
396+
387397
#endif // !(ORT_MINIMAL_BUILD)
388398

389399
TEST(NnapiExecutionProviderTest, NNAPIFlagsTest) {

0 commit comments

Comments (0)